Merge branch 'develop' into LWPCK-3549-cleanups

This commit is contained in:
SamiAario-AMD
2026-01-14 10:59:05 +02:00
committed by GitHub
210 changed files with 12028 additions and 2713 deletions

View File

@@ -43,6 +43,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added top-k sigmoid kernel in CK_TILE
* Added the blockscale 2D support for CK_TILE GEMM.
* Added Flatmm pipeline for microscaling (MX) FP8/FP4 data types
* Added reduce and multi reduction kernels
### Changed

62
Jenkinsfile vendored
View File

@@ -574,6 +574,8 @@ def cmake_build(Map conf=[:]){
def setup_cmd
def build_cmd
def execute_cmd = conf.get("execute_cmd", "")
//check the node gpu architecture
def arch_name = check_arch_name()
if(!setup_args.contains("NO_CK_BUILD")){
if (params.NINJA_BUILD_TRACE) {
echo "running ninja build trace"
@@ -646,15 +648,15 @@ def cmake_build(Map conf=[:]){
//run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set
if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){
sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${check_arch_name()}.json"
archiveArtifacts "ck_build_trace_${check_arch_name()}.json"
sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${check_arch_name()}.json"
sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${arch_name}.json"
archiveArtifacts "ck_build_trace_${arch_name}.json"
sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${arch_name}.json"
if (params.NINJA_BUILD_TRACE || params.BUILD_INSTANCES_ONLY){
if (params.NINJA_FTIME_TRACE) {
echo "running ClangBuildAnalyzer"
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log"
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${check_arch_name()}.log"
archiveArtifacts "clang_build_analysis_${check_arch_name()}.log"
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${arch_name}.log"
archiveArtifacts "clang_build_analysis_${arch_name}.log"
}
@@ -672,8 +674,8 @@ def cmake_build(Map conf=[:]){
if(params.BUILD_PACKAGES){
echo "Build ckProfiler packages"
sh 'ninja -j64 package'
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${check_arch_name()}.deb"
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${check_arch_name()}"
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb"
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
}
}
if(params.BUILD_INSTANCES_ONLY){
@@ -699,16 +701,14 @@ def cmake_build(Map conf=[:]){
if(params.BUILD_PACKAGES){
echo "Build ckProfiler packages"
sh 'ninja -j64 package'
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${check_arch_name()}.deb"
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${check_arch_name()}"
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb"
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
}
}
}
}
}
//check the node gpu architecture
def arch_name = check_arch_name()
if (params.RUN_CK_TILE_FMHA_TESTS){
try{
archiveArtifacts "perf_fmha_*.log"
@@ -1201,8 +1201,8 @@ pipeline {
description: "Run the ck_tile FMHA tests (default: OFF)")
booleanParam(
name: "RUN_TILE_ENGINE_BASIC_TESTS",
defaultValue: false,
description: "Run the tile_engine_basic tests (default: OFF)")
defaultValue: true,
description: "Run the tile_engine_basic tests (default: ON)")
booleanParam(
name: "RUN_TILE_ENGINE_GEMM_TESTS",
defaultValue: false,
@@ -1650,7 +1650,10 @@ pipeline {
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
-D GEMM_PRESHUFFLE_LAYOUT="rcr" \
-D GEMM_PRESHUFFLE_CONFIG_FILE="default_ci_config.json" .. && \
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all """
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args)
@@ -1667,37 +1670,6 @@ pipeline {
}
parallel
{
// Stage: configure, build and run the tile_engine GEMM benchmarks for the gfx90a target.
stage("Run TILE_ENGINE_GEMM Tests on gfx90a")
{
// Only runs when the RUN_TILE_ENGINE_GEMM_TESTS job parameter is true.
// NOTE(review): beforeAgent presumably evaluates the condition before an agent is allocated - confirm against the Jenkins docs.
when {
beforeAgent true
expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() }
}
// Run on a node selected by rocmnode("gfx90a").
agent{ label rocmnode("gfx90a") }
environment{
// NOTE(review): "NO_CK_BUILD" appears to make cmake_build skip the regular CK build so only execute_args runs - confirm.
setup_args = "NO_CK_BUILD"
// Shell pipeline chained with &&: cmake configure, ninja build of the benchmark targets,
// then one Python benchmark driver per GEMM flavour (universal, preshuffle, multi_d).
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx90a" \
-D GEMM_UNIVERSAL_DATATYPE="fp8;fp16" \
-D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \
-D GEMM_STREAMK_DATATYPE="fp8;fp16" \
-D GEMM_STREAMK_LAYOUT="rcr" \
-D GEMM_MULTI_D_DATATYPE="fp16" \
-D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
-D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
}
steps{
// Hand the command line to the shared build helper, then wipe the workspace.
buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
stage("Run TILE_ENGINE_GEMM Tests on gfx942")
{
when {

View File

@@ -44,6 +44,9 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl
add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp)
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16)
add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp)
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16)
list(APPEND gpu_list_tf32 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)

View File

@@ -0,0 +1,76 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <ck/utility/data_type.hpp>
#include <ck/utility/tuple.hpp>
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp"
// Host-side helpers from the ck namespace that the included example body refers to unqualified.
using ::ck::DeviceMem;
using ::ck::hip_check_error;
using ::ck::HostTensorDescriptor;
using ::ck::Tensor;
// Shorthand for a compile-time integer sequence; used for the tuning parameters of DeviceGemmInstance.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
// Scalar type shorthands.
using F16 = ck::half_t;
using F32 = float;
// Matrix layout shorthands.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// Element-wise operation shorthands.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddAdd = ck::tensor_operation::element_wise::AddAdd;
// Data types: F16 inputs/outputs, F32 for accumulation and the C-shuffle stage.
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
// Two auxiliary D tensors of the same type (must agree with NumDs below).
using DsDataType = ck::Tuple<DDataType, DDataType>;
using EDataType = F16;
// Layouts: A row-major, B column-major, both D tensors and E row-major.
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using DsLayout = ck::Tuple<DLayout, DLayout>;
using ELayout = Row;
// A and B are passed through unchanged; the output stage applies AddAdd.
// NOTE(review): AddAdd presumably computes E = C + D0 + D1 - confirm in element_wise_operation.hpp.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddAdd;
// GEMM specialization passed to the device op (MNKPadding variant).
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Number of D tensors; matches the two entries in DsDataType and DsLayout.
static constexpr int NumDs = 2;
// Concrete device operation used by this example. The template arguments are positional;
// the comment table below names each column, so keep argument order and table in sync.
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<4, 4, 4>>;
// clang-format on
#include "run_grouped_gemm_multiple_d_example.inc"
// Program entry point: forwards the command line to the shared example driver and
// maps its success flag onto the process exit status (0 on success, 1 on failure).
int main(int argc, char* argv[])
{
    const bool succeeded = run_grouped_gemm_example(argc, argv);
    return succeeded ? 0 : 1;
}

View File

@@ -71,339 +71,6 @@ using DeviceGemmInstance =
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>;
// clang-format on
// Per-group GEMM problem description: entry i of each vector belongs to group i.
struct ProblemSize final
{
// M, N and K extent of each group.
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
// NOTE(review): this include sits inside the struct only because the diff view interleaves
// added and removed lines; presumably it replaces the inline definitions in the new file - confirm.
#include "run_grouped_gemm_multiple_d_example.inc"
// Strides of the A, B, D and C/E tensors of each group.
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<std::vector<ck::index_t>> stride_Ds;
std::vector<ck::index_t> stride_Cs;
// Number of GEMM groups to run.
ck::index_t group_count;
};
// Run-time switches for the example driver.
struct ExecutionConfig final
{
    // Check the device result against the host reference GEMM.
    bool do_verification{true};
    // Tensor initialisation scheme: 0 leaves the tensors untouched, 1 and 2 select
    // GeneratorTensor_2 / GeneratorTensor_3, any other value uses sequential values.
    int init_method{1};
    // Run the kernel again with timing enabled and print performance figures.
    bool time_kernel{false};
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDs>;
using GemmDesc = ck::tensor_operation::device::GemmDesc;
// GEMM shape
std::vector<GemmDesc> gemm_descs;
std::vector<KernelArguments> ggemm_kargs;
std::vector<void*> p_Cs;
std::vector<const void*> p_As;
std::vector<const void*> p_Bs;
std::vector<std::array<const void*, NumDs>> p_Ds = {};
gemm_descs.reserve(group_count);
ggemm_kargs.reserve(group_count);
p_As.reserve(group_count);
p_Bs.reserve(group_count);
p_Ds.reserve(group_count);
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<std::array<Tensor<DDataType>, NumDs>> d_tensors;
std::vector<Tensor<EDataType>> c_host_tensors;
std::vector<Tensor<EDataType>> c_device_result_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
d_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_result_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
std::vector<std::vector<DeviceMemPtr>> d_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
d_tensors_device.resize(group_count); // reserve and update vector size
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
std::array<Tensor<DDataType>, NumDs> d_tens = {d0_tensor, d1_tensor};
d_tensors.push_back(d_tens);
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc
<< " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
sizeof(BDataType) * b_tensors[i].GetElementSize() +
sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs +
sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();
switch(config.init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
for(int j = 0; j < NumDs; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
}
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
for(int j = 0; j < NumDs; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
}
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
for(int j = 0; j < NumDs; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
}
}
}
for(int i = 0; i < group_count; i++)
{
a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));
b_tensors_device.emplace_back(
std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
for(int j = 0; j < NumDs; ++j)
{
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
}
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
for(int j = 0; j < NumDs; ++j)
{
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
}
c_tensors_device[i]->SetZero();
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_Ds.push_back(
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
// The device op does not have to know M problem size at lunch time.
gemm_descs.push_back({0,
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
problem_size.stride_Cs[i],
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]}});
ggemm_kargs.push_back(
{a_tensors_device[i]->GetDeviceBuffer(),
b_tensors_device[i]->GetDeviceBuffer(),
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()},
c_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]},
problem_size.stride_Cs[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
ggemm_kargs.data(),
gemm.GetDeviceKernelArgSize(&argument),
hipMemcpyHostToDevice));
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false, 1});
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
BDataType,
DsDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
auto karg = ggemm_kargs[i];
auto dev_res_tensor =
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{}));
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
d_tensors[i],
c_host_tensors[i],
a_element_op,
b_element_op,
cde_element_op);
ref_invoker.Run(ref_argument);
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
}
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
}
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
return pass;
}
// Splits a comma-separated command-line argument into integers.
// Each piece is converted with std::stoi, so a piece that is not a number makes std::stoi throw.
std::vector<int> argToIntArray(char* input)
{
    std::istringstream stream(input);
    std::vector<int> values;

    for(std::string piece; std::getline(stream, piece, ',');)
    {
        values.push_back(std::stoi(piece));
    }

    return values;
}
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
if(argc < 10)
{
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
problem_size.group_count = Ms.size();
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(Ms[i]);
problem_size.Ns.push_back(252);
problem_size.Ks.push_back(4608);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
problem_size.stride_Ds.push_back({});
for(int j = 0; j < NumDs; ++j)
{
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
}
}
std::cout
<< "Usage:\n"
<< "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=n0, 1=yes)\n"
<< "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "... setting default values." << std::endl;
}
else
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.Ms = argToIntArray(argv[4]);
problem_size.Ns = argToIntArray(argv[5]);
problem_size.Ks = argToIntArray(argv[6]);
problem_size.stride_As = argToIntArray(argv[7]);
problem_size.stride_Bs = argToIntArray(argv[8]);
problem_size.stride_Cs = argToIntArray(argv[9]);
for(int j = 0; j < NumDs; ++j)
{
problem_size.stride_Ds.push_back(problem_size.stride_Cs);
}
problem_size.group_count = problem_size.Ms.size();
}
return !run_grouped_gemm(problem_size, config);
}
// Program entry point: forwards the command line to the shared example driver and
// maps its success flag onto the process exit status (0 on success, 1 on failure).
int main(int argc, char* argv[])
{
    const bool succeeded = run_grouped_gemm_example(argc, argv);
    return succeeded ? 0 : 1;
}

View File

@@ -58,11 +58,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
// clang-format on

View File

@@ -57,11 +57,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
// clang-format on

View File

@@ -323,8 +323,8 @@ bool run_grouped_gemm_example(int argc, char* argv[])
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4: async hargs (0=n0, 1=yes)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4: async hargs (0=no, 1=yes)\n");
printf("arg5: group count (default=16)\n");
#if defined(EXAMPLE_USE_SPLITK)
printf("arg6: k-batch count (default=1)\n");

View File

@@ -0,0 +1,341 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
// Per-group GEMM problem description: entry i of each vector belongs to group i.
struct ProblemSize final
{
    // M, N and K extent of each group.
    std::vector<ck::index_t> Ms;
    std::vector<ck::index_t> Ns;
    std::vector<ck::index_t> Ks;

    // Strides of the A, B, D and C/E tensors of each group.
    std::vector<ck::index_t> stride_As;
    std::vector<ck::index_t> stride_Bs;
    std::vector<std::vector<ck::index_t>> stride_Ds;
    std::vector<ck::index_t> stride_Cs;

    // Number of GEMM groups to run. Zero-initialised so a default-constructed ProblemSize
    // describes an empty problem instead of carrying an indeterminate count.
    ck::index_t group_count = 0;
};
// Run-time switches for the example driver.
struct ExecutionConfig final
{
    // Whether the result should be verified (enabled by default).
    bool do_verification{true};
    // Tensor initialisation scheme: 0 leaves the tensors untouched, 1 and 2 select
    // GeneratorTensor_2 / GeneratorTensor_3, any other value uses sequential values.
    int init_method{1};
    // Whether the kernel should be timed (disabled by default).
    bool time_kernel{false};
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDs>;
using GemmDesc = ck::tensor_operation::device::GemmDesc;
// GEMM shape
std::vector<GemmDesc> gemm_descs;
std::vector<KernelArguments> ggemm_kargs;
std::vector<void*> p_Cs;
std::vector<const void*> p_As;
std::vector<const void*> p_Bs;
std::vector<std::array<const void*, NumDs>> p_Ds = {};
gemm_descs.reserve(group_count);
ggemm_kargs.reserve(group_count);
p_As.reserve(group_count);
p_Bs.reserve(group_count);
p_Ds.reserve(group_count);
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<std::array<Tensor<DDataType>, NumDs>> d_tensors;
std::vector<Tensor<EDataType>> c_host_tensors;
std::vector<Tensor<EDataType>> c_device_result_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
d_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_result_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
std::vector<std::vector<DeviceMemPtr>> d_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
d_tensors_device.resize(group_count); // reserve and update vector size
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
std::array<Tensor<DDataType>, NumDs> d_tens = {d0_tensor, d1_tensor};
d_tensors.push_back(d_tens);
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc
<< " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
sizeof(BDataType) * b_tensors[i].GetElementSize() +
sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs +
sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();
switch(config.init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
for(int j = 0; j < NumDs; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
}
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
for(int j = 0; j < NumDs; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
}
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
for(int j = 0; j < NumDs; ++j)
{
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
}
}
}
for(int i = 0; i < group_count; i++)
{
a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));
b_tensors_device.emplace_back(
std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
for(int j = 0; j < NumDs; ++j)
{
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
}
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
for(int j = 0; j < NumDs; ++j)
{
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
}
c_tensors_device[i]->SetZero();
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_Ds.push_back(
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
// The device op does not have to know the M problem size at launch time.
gemm_descs.push_back({0,
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
problem_size.stride_Cs[i],
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]}});
ggemm_kargs.push_back(
{a_tensors_device[i]->GetDeviceBuffer(),
b_tensors_device[i]->GetDeviceBuffer(),
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()},
c_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]},
problem_size.stride_Cs[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
ggemm_kargs.data(),
gemm.GetDeviceKernelArgSize(&argument),
hipMemcpyHostToDevice));
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false, 1});
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
BDataType,
DsDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
auto karg = ggemm_kargs[i];
auto dev_res_tensor =
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{}));
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
d_tensors[i],
c_host_tensors[i],
a_element_op,
b_element_op,
cde_element_op);
ref_invoker.Run(ref_argument);
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
}
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
}
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
return pass;
}
// Splits a comma-separated command-line argument (e.g. "256,128,64") and converts
// every field to an int with std::stoi. An empty string yields an empty vector.
std::vector<int> argToIntArray(char* input)
{
    std::istringstream stream{std::string(input)};
    std::vector<int> values;

    for(std::string token; std::getline(stream, token, ',');)
    {
        values.emplace_back(std::stoi(token));
    }

    return values;
}
// Builds the grouped GEMM problem from the command line and runs it.
//
// With fewer than nine user arguments the built-in default problem (seven groups) is used
// and the usage text is printed. Otherwise argv[1..3] select verification, initialization
// method and kernel timing, and argv[4..9] are comma-separated per-group lists of
// Ms, Ns, Ks, StrideAs, StrideBs and StrideCs.
//
// Returns the result of run_grouped_gemm(), or false when the per-group lists given on the
// command line do not all contain the same number of entries.
bool run_grouped_gemm_example(int argc, char* argv[])
{
    ProblemSize problem_size;
    ExecutionConfig config;

    if(argc < 10)
    {
        std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
        problem_size.group_count = Ms.size();

        for(int i = 0; i < problem_size.group_count; i++)
        {
            problem_size.Ms.push_back(Ms[i]);
            problem_size.Ns.push_back(252);
            problem_size.Ks.push_back(4608);

            problem_size.stride_As.push_back(problem_size.Ks[i]);
            problem_size.stride_Bs.push_back(problem_size.Ks[i]);
            problem_size.stride_Cs.push_back(problem_size.Ns[i]);

            // One entry per group, holding one stride per D tensor.
            problem_size.stride_Ds.push_back({});
            for(int j = 0; j < NumDs; ++j)
            {
                problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
            }
        }

        std::cout
            << "Usage:\n"
            << "arg1: verification (0=no, 1=yes)\n"
            << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
            << "arg3: time kernel (0=no, 1=yes)\n"
            << "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
               "64,64 64,64 128,128)\n"
            << "... setting default values." << std::endl;
    }
    else
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);

        problem_size.Ms = argToIntArray(argv[4]);
        problem_size.Ns = argToIntArray(argv[5]);
        problem_size.Ks = argToIntArray(argv[6]);

        problem_size.stride_As = argToIntArray(argv[7]);
        problem_size.stride_Bs = argToIntArray(argv[8]);
        problem_size.stride_Cs = argToIntArray(argv[9]);

        // run_grouped_gemm() indexes every list with the same group index, so a list that
        // is shorter than Ms would be read out of bounds. Reject such input up front.
        const auto num_groups = problem_size.Ms.size();
        if(problem_size.Ns.size() != num_groups || problem_size.Ks.size() != num_groups ||
           problem_size.stride_As.size() != num_groups ||
           problem_size.stride_Bs.size() != num_groups ||
           problem_size.stride_Cs.size() != num_groups)
        {
            std::cerr << "Error: Ms, Ns, Ks, StrideAs, StrideBs and StrideCs must all contain "
                         "the same number of comma-separated values."
                      << std::endl;
            return false;
        }

        problem_size.group_count = problem_size.Ms.size();

        // Use the same layout as the default branch above: one entry per group, holding
        // one stride per D tensor (each D tensor shares the stride of C).
        for(int i = 0; i < problem_size.group_count; i++)
        {
            problem_size.stride_Ds.push_back({});
            for(int j = 0; j < NumDs; ++j)
            {
                problem_size.stride_Ds[i].push_back(problem_size.stride_Cs[i]);
            }
        }
    }

    return run_grouped_gemm(problem_size, config);
}

View File

@@ -119,7 +119,7 @@ static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_an
static constexpr bool MulRoutedWeight = false; // splitk gemm1 does not do routedWeight.
#if 1
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t MPerBlock = 64;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
@@ -156,7 +156,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight,
int32_t, A0DataType, A0DataType, A0DataType, A0DataType, true>;
#else
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
@@ -171,7 +172,8 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight,
int32_t, A0DataType, A0DataType, A0DataType, A0DataType, false>;
#endif
// clang-format on
@@ -182,12 +184,14 @@ int main(int argc, char* argv[])
bool time_kernel = true;
#if 1
// GEMM shape
ck::index_t N = 4096;
ck::index_t K = 6144;
ck::index_t N = 1536;
ck::index_t K = 4096;
// ck::index_t N = 4096;
// ck::index_t K = 6144;
// ck::index_t N = 128;
// ck::index_t K = 512;
ck::index_t experts = 8;
ck::index_t topk = 2;
ck::index_t experts = 16;
ck::index_t topk = 8;
// ck::index_t sorted_tile_num = 515;
// ck::index_t valid_tile_num = 512;
// ck::index_t tokens = 208;
@@ -196,9 +200,9 @@ int main(int argc, char* argv[])
// ck::index_t sorted_tile_num = 259;
// ck::index_t valid_tile_num = 256;
// ck::index_t tokens = 4096;
ck::index_t sorted_tile_num = 2;
ck::index_t valid_tile_num = 2;
ck::index_t tokens = 32;
ck::index_t sorted_tile_num = 16;
ck::index_t valid_tile_num = 16;
ck::index_t tokens = 4;
#else
// deepseek
ck::index_t N = 2048;
@@ -209,7 +213,7 @@ int main(int argc, char* argv[])
ck::index_t sorted_tile_num = 261;
ck::index_t valid_tile_num = 256;
#endif
ck::index_t KBatch = 6;
ck::index_t KBatch = 1;
if(argc == 1)
{
// use default case

View File

@@ -36,7 +36,7 @@ DTYPE_BITS = {
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
SUPPORTED_PAGE_SIZE = [128, 256, 1024]
SUPPORTED_PAGE_SIZE = [1, 128, 256, 1024]
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
KV_MEMORY_LAYOUT_ENUM_MAP = {
@@ -737,6 +737,8 @@ def get_fwd_blobs(
# Generate kernels for every page size in SUPPORTED_PAGE_SIZE
for page_size in SUPPORTED_PAGE_SIZE:
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
continue
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,

View File

@@ -1351,8 +1351,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
auto oacc_element_func = [&]() {
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t> && supports_qscale)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o_host});
return ck_tile::make_composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o_host});
else if constexpr(supports_qscale)
return ck_tile::scales{scale_o_host};
else

View File

@@ -15,6 +15,22 @@ list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-flo
target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS})
# Multi Reduce Threadwise Example
set(EXAMPLE_MULTI_REDUCE "tile_example_multi_reduce_threadwise")
add_executable(${EXAMPLE_MULTI_REDUCE} EXCLUDE_FROM_ALL multiple_reduce_threadwise.cpp)
target_include_directories(${EXAMPLE_MULTI_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set(EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS)
list(APPEND EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${EXAMPLE_MULTI_REDUCE} PRIVATE ${EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS})
# Multi Reduce Multiblock Example
set(EXAMPLE_MULTI_REDUCE_BLOCKWISE "tile_example_multi_reduce_multiblock")
add_executable(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} EXCLUDE_FROM_ALL multiple_reduce_multiblock.cpp)
target_include_directories(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set(EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS)
list(APPEND EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global

View File

@@ -0,0 +1,271 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <cstring>
// Maps an element type to a short, human-readable precision name.
// The primary template is only declared; a type is usable only if it has one of the
// explicit specializations below.
template <typename T>
struct DataTypeTraits;
// ck_tile::half_t is reported as "fp16".
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
// ck_tile::bf16_t is reported as "bf16".
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("n", "32", "n dimension")
.insert("h", "19", "h dimension")
.insert("w", "7", "w dimension")
.insert("c", "512", "c dimension")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "multi_reduce_multiblock.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// Runs the multi-block multi-reduction example for input element type DataType.
//
// A contiguous N x H x W x C input is reduced over H and W (N and C are kept) by two
// operations at once: a plain sum and a mean of squares. Both results are written to one
// device buffer, one N x C output after the other. The kernel is timed, and when the "v"
// option is non-zero both outputs are compared against the host reference.
//
// Returns false when N*C or H*W is zero or when validation fails, true otherwise.
// Throws std::runtime_error when the kernel does not support the arguments.
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    using XDataType       = DataType;
    using ComputeDataType = float;
    using YDataType       = float;

    ck_tile::index_t N = arg_parser.get_int("n");
    ck_tile::index_t H = arg_parser.get_int("h");
    ck_tile::index_t W = arg_parser.get_int("w");
    ck_tile::index_t C = arg_parser.get_int("c");
    int do_validation  = arg_parser.get_int("v");
    int warmup         = arg_parser.get_int("warmup");
    int repeat         = arg_parser.get_int("repeat");

    // Validate input dimensions
    const ck_tile::index_t kept_dim_len_prod   = N * C;
    const ck_tile::index_t reduce_total_length = H * W;

    if(kept_dim_len_prod == 0)
    {
        std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C
                  << ", product=" << kept_dim_len_prod << ")." << std::endl;
        std::cerr << "This will result in an empty output tensor." << std::endl;
        return false;
    }

    if(reduce_total_length == 0)
    {
        std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W
                  << ", product=" << reduce_total_length << ")." << std::endl;
        std::cerr << "This will result in an empty reduction with no data to process." << std::endl;
        std::cerr << "The kernel will exit early without performing any computation." << std::endl;
        return false;
    }

    // Contiguous NHWC layout: C is the fastest-varying dimension.
    std::vector<ck_tile::index_t> problem_shape = {N, H, W, C};
    std::vector<ck_tile::index_t> strides(4);
    strides[0] = H * W * C;
    strides[1] = W * C;
    strides[2] = C;
    strides[3] = 1;

    // Define reduction specification:
    constexpr auto kept_dim    = ck_tile::sequence<0, 3>{}; // Which dimensions to keep
    constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce

    ck_tile::HostTensor<XDataType> x_host(problem_shape, strides);

    // One N x C host tensor per operation (sum, mean of squares), for the reference
    // results and for the results read back from the device.
    ck_tile::HostTensor<YDataType> y_host_add_ref({N, C}, {C, 1});
    ck_tile::HostTensor<YDataType> y_host_mean_sq_ref({N, C}, {C, 1});
    auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_mean_sq_ref);

    ck_tile::HostTensor<YDataType> y_host_add_dev({N, C}, {C, 1});
    ck_tile::HostTensor<YDataType> y_host_mean_sq_dev({N, C}, {C, 1});
    auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_mean_sq_dev);

    const auto number_operations = y_host_dev_tuple.size();

    // Host staging buffer holding all outputs back to back; it is also used to
    // (re-)initialize the device output buffer before each kernel run.
    std::vector<YDataType> h(number_operations * N * C);

    auto y_buf_size = number_operations *
                      y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes();
    ck_tile::DeviceMem y_buf(y_buf_size);

    // Distance, in elements, between the outputs of two consecutive operations.
    const auto output_tensor_offset = N * C;

    // Operations: one doing a sum reduction, the other computing the mean square.
    // In the case of mean square:
    // 1. The element-wise operation squares each element before reduction.
    // 2. The reduction operation sums the squared elements.
    // 3. The accumulator element-wise operation divides the result by the total number of
    //    reduced elements (intra-block operation).
    // 4. The partial result is updated across blocks using an inter-block reduction, a sum.
    auto reduce_ops =
        ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reductions

    auto elementwise_ops = ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
                                               ck_tile::element_wise::UnarySquare{});

    // Accumulator element-wise ops applied to the reduction result, intra block.
    auto accumulator_elementwise_ops =
        ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
                            ck_tile::element_wise::UnaryDivide{reduce_total_length});

    // Inter-block reductions.
    auto inter_block_reduce_ops =
        ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{});

    ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);

    ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
    x_buf.ToDevice(x_host.data());

    using BlockWarps = ck_tile::sequence<4, 1>;
    using BlockTile  = ck_tile::sequence<128, 128>;
    using WarpTile   = ck_tile::sequence<32, 128>;
    using ThreadTile = ck_tile::sequence<8, 8>;

    constexpr ck_tile::index_t kBlockPerCu = 1;

    using Shape   = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
    using Problem = ck_tile::Reduce2dProblem<XDataType,
                                             ComputeDataType,
                                             YDataType,
                                             Shape,
                                             decltype(reduce_ops),
                                             decltype(kept_dim),
                                             decltype(reduce_dims),
                                             4>;
    using Kernel  = ck_tile::MultiReduceMultiblock<Problem>;

    // Determine the block group size for the multi-block reduction.
    // block_group_size records how many blocks participate in one reduction (input data
    // dependent); for efficiency reasons this size is limited to a maximum of 128. If that
    // is not sufficient to process the whole reduction, each block processes multiple
    // block tiles, num_block_tile_iterations times.
    auto [num_block_tile_iterations, block_group_size] =
        typename Kernel::TilePartitioner{reduce_total_length}.GetBlockGroupParams();

    const ck_tile::index_t kBlockSize = Kernel::BlockSize();
    ck_tile::index_t kGridSize =
        ((kept_dim_len_prod + Shape::Block_M - 1) / Shape::Block_M) * block_group_size;

    std::cout << "Block group size: " << block_group_size
              << ", Num block tile iterations: " << num_block_tile_iterations
              << ", Reduce total length: " << reduce_total_length << std::endl;
    std::cout << "grid size " << kGridSize << ", block size " << kBlockSize << std::endl;

    // Create input tensor shape and strides
    auto input_shape =
        ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
    auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]);

    if(!Kernel::IsSupportedArgument(
           C, input_strides)) // output tensor's continuous dimension and input strides
    {
        throw std::runtime_error("Wrong! Arguments not supported!\n");
    }

    // Init the output data with identity values respective to each reduce op
    ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
        constexpr auto op                 = reduce_ops.at(i);
        const auto identity_val           = op.template GetIdentityValue<YDataType>();
        const auto output_number_elements = N * C;
        std::fill(h.begin() + i * output_number_elements,
                  h.begin() + (i + 1) * output_number_elements,
                  identity_val);
    });

    // Handed to launch_kernel_time_mask together with the kernel: uploads the identity
    // values to the device output buffer.
    auto clear_output_buffer = [&]() { y_buf.ToDevice(h.data()); };

    float ave_time = launch_kernel_time_mask(
        ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
        clear_output_buffer,
        ck_tile::make_kernel<kBlockPerCu>(Kernel{},
                                          kGridSize,
                                          kBlockSize,
                                          0,
                                          static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
                                          static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
                                          input_shape,
                                          input_strides,
                                          kept_dim,
                                          reduce_dims,
                                          output_tensor_offset,
                                          elementwise_ops,
                                          accumulator_elementwise_ops,
                                          inter_block_reduce_ops));

    // Bytes moved: the input once, plus one N x C output per reduction operation.
    std::size_t num_btype =
        sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * number_operations * N * C;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;

    bool pass = true;

    if(do_validation)
    {
        // reference
        ck_tile::reference_multiple_reduce_multiblock<XDataType, ComputeDataType, YDataType>(
            x_host,
            y_host_ref_tuple,
            reduce_ops,
            kept_dim,
            reduce_dims,
            elementwise_ops,
            accumulator_elementwise_ops,
            inter_block_reduce_ops,
            block_group_size);

        // The whole output buffer (all operations) is copied back.
        std::cout << "Read " << y_buf_size << " Bytes from the device" << std::endl;

        // Transfer data from device and check error for each operation
        y_buf.FromDevice(h.data());
        ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
            std::memcpy(y_host_dev_tuple.get(ck_tile::number<i>{}).data(),
                        h.data() + i * output_tensor_offset,
                        output_tensor_offset * sizeof(YDataType));
            std::cout << "Checking operation " << i << ": " << std::endl;

            bool pass_op = ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number<i>{}),
                                              y_host_ref_tuple.get(ck_tile::number<i>{}));
            if(pass_op)
            {
                std::cout << "✅ valid results for this operation" << std::endl;
            }
            pass &= pass_op;
        });

        std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    }

    return pass;
}
// Parses the command line and runs the example for the requested precision.
//
// Exit codes: 0 on success, -1 when argument parsing fails, -2 when run() returns false,
// -3 when the requested precision is not supported.
int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    const std::string data_type = arg_parser.get_str("prec");

    if(data_type == "fp16")
    {
        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
    }

    // Without this, control would fall off the end of main() and the process would exit
    // with 0 although nothing was executed.
    std::cerr << "Unsupported precision: " << data_type << " (supported: fp16)" << std::endl;
    return -3;
}

View File

@@ -0,0 +1,224 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <cstring>
// Maps an element type to a short, human-readable precision name.
// The primary template is only declared; a type is usable only if it has one of the
// explicit specializations below.
template <typename T>
struct DataTypeTraits;
// ck_tile::half_t is reported as "fp16".
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
// ck_tile::bf16_t is reported as "bf16".
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("n", "32", "n dimension")
.insert("h", "7", "h dimension")
.insert("w", "7", "w dimension")
.insert("c", "512", "c dimension")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "multi_reduce.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// Runs the thread-wise multi-reduction example for element type DataType.
//
// A contiguous N x H x W x C input is reduced over H and W (N and C are kept) by two
// operations at once: a plain sum and a mean of squares. Both results are written to one
// device buffer, one N x C output after the other. The kernel is timed, and when the "v"
// option is non-zero both outputs are compared against the host reference.
//
// Returns false when N*C or H*W is zero or when validation fails, true otherwise.
// Throws std::runtime_error when the kernel does not support the arguments.
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    using XDataType       = DataType;
    using ComputeDataType = float;
    using YDataType       = DataType;

    ck_tile::index_t N = arg_parser.get_int("n");
    ck_tile::index_t H = arg_parser.get_int("h");
    ck_tile::index_t W = arg_parser.get_int("w");
    ck_tile::index_t C = arg_parser.get_int("c");
    int do_validation  = arg_parser.get_int("v");
    int warmup         = arg_parser.get_int("warmup");
    int repeat         = arg_parser.get_int("repeat");

    // Validate input dimensions
    const ck_tile::index_t kept_dim_len_prod   = N * C;
    const ck_tile::index_t reduce_total_length = H * W;

    if(kept_dim_len_prod == 0)
    {
        std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C
                  << ", product=" << kept_dim_len_prod << ")." << std::endl;
        std::cerr << "This will result in an empty output tensor." << std::endl;
        return false;
    }

    if(reduce_total_length == 0)
    {
        std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W
                  << ", product=" << reduce_total_length << ")." << std::endl;
        std::cerr << "This will result in an empty reduction with no data to process." << std::endl;
        std::cerr << "The kernel will exit early without performing any computation." << std::endl;
        return false;
    }

    // Contiguous NHWC layout: C is the fastest-varying dimension.
    std::vector<ck_tile::index_t> problem_shape = {N, H, W, C};
    std::vector<ck_tile::index_t> strides(4);
    strides[0] = H * W * C;
    strides[1] = W * C;
    strides[2] = C;
    strides[3] = 1;

    // Define reduction specification:
    constexpr auto kept_dim    = ck_tile::sequence<0, 3>{}; // Which dimensions to keep
    constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce

    ck_tile::HostTensor<XDataType> x_host(problem_shape, strides);

    // One N x C host tensor per operation (sum, mean of squares), for the reference
    // results and for the results read back from the device.
    ck_tile::HostTensor<YDataType> y_host_add_ref({N, C}, {C, 1});
    ck_tile::HostTensor<YDataType> y_host_mean_sq_ref({N, C}, {C, 1});
    auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_mean_sq_ref);

    ck_tile::HostTensor<YDataType> y_host_add_dev({N, C}, {C, 1});
    ck_tile::HostTensor<YDataType> y_host_mean_sq_dev({N, C}, {C, 1});
    auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_mean_sq_dev);

    const auto number_operations = y_host_dev_tuple.size();

    // Two operations: one does a sum reduction, the other computes the mean square
    // (square each element, sum, then divide by the number of reduced elements).
    auto reduce_ops =
        ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reduction ops

    auto elementwise_ops = ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
                                               ck_tile::element_wise::UnarySquare{});

    // Accumulator element-wise ops applied to the reduction result.
    auto accumulator_elementwise_ops =
        ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
                            ck_tile::element_wise::UnaryDivide{reduce_total_length});

    auto y_buf_size = number_operations *
                      y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes();
    ck_tile::DeviceMem y_buf(y_buf_size);

    // Distance, in elements, between the outputs of two consecutive operations.
    const auto output_tensor_offset = N * C;

    ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);

    ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
    x_buf.ToDevice(x_host.data());

    using BlockWarps = ck_tile::sequence<4, 1>;
    using BlockTile  = ck_tile::sequence<128, 128>;
    using WarpTile   = ck_tile::sequence<32, 128>;
    using ThreadTile = ck_tile::sequence<8, 8>;

    constexpr ck_tile::index_t kBlockPerCu = 1;

    // Ceiling division of the kept-dimension product by the first block-tile extent.
    ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) /
                                 BlockTile::at(ck_tile::number<0>{});
    std::cout << "grid size " << kGridSize << std::endl;

    using Shape   = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
    using Problem = ck_tile::Reduce2dProblem<XDataType,
                                             ComputeDataType,
                                             YDataType,
                                             Shape,
                                             decltype(reduce_ops),
                                             decltype(kept_dim),
                                             decltype(reduce_dims),
                                             4>;
    using Kernel  = ck_tile::MultiReduceThreadWise<Problem>;

    const ck_tile::index_t kBlockSize = Kernel::BlockSize();

    // Create input tensor shape and strides
    auto input_shape =
        ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
    auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]);

    if(!Kernel::IsSupportedArgument(
           C, input_strides)) // output tensor's continuous dimension and input strides
    {
        throw std::runtime_error("Wrong! Arguments not supported!\n");
    }

    float ave_time = launch_kernel(
        ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
        ck_tile::make_kernel<kBlockPerCu>(Kernel{},
                                          kGridSize,
                                          kBlockSize,
                                          0,
                                          static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
                                          static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
                                          input_shape,
                                          input_strides,
                                          kept_dim,
                                          reduce_dims,
                                          output_tensor_offset,
                                          elementwise_ops,
                                          accumulator_elementwise_ops));

    // Bytes moved: the input once, plus one N x C output per reduction operation.
    std::size_t num_btype =
        sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * number_operations * N * C;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;

    bool pass = true;

    if(do_validation)
    {
        // Host staging buffer that receives all outputs back to back.
        std::vector<YDataType> h(number_operations * N * C);

        // reference
        ck_tile::reference_multiple_reduce<XDataType, ComputeDataType, YDataType>(
            x_host,
            y_host_ref_tuple,
            reduce_ops,
            kept_dim,
            reduce_dims,
            elementwise_ops,
            accumulator_elementwise_ops);

        // The whole output buffer (all operations) is copied back.
        std::cout << "Read " << y_buf_size << " Bytes from the device" << std::endl;

        // Transfer data from device and check error for each operation
        y_buf.FromDevice(h.data());
        ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
            std::memcpy(y_host_dev_tuple.get(ck_tile::number<i>{}).data(),
                        h.data() + i * output_tensor_offset,
                        output_tensor_offset * sizeof(YDataType));
            pass &= ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number<i>{}),
                                       y_host_ref_tuple.get(ck_tile::number<i>{}));
        });

        std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    }

    return pass;
}
// Parses the command line and runs the example for the requested precision.
//
// Exit codes: 0 on success, -1 when argument parsing fails, -2 when run() returns false,
// -3 when the requested precision is not supported.
int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    const std::string data_type = arg_parser.get_str("prec");

    if(data_type == "fp16")
    {
        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
    }

    // Without this, control would fall off the end of main() and the process would exit
    // with 0 although nothing was executed.
    std::cerr << "Unsupported precision: " << data_type << " (supported: fp16)" << std::endl;
    return -3;
}

View File

@@ -45,6 +45,11 @@ cmake
..
```
Note: The tests for WMMA builders are only built when `CK_USE_WMMA` is enabled. To enable it, add e.g.
`gfx1121` or any of the other `gfx11`/`gfx12` architectures to the GPU targets. Alternatively,
pass the flag `-D CK_USE_WMMA=ON` to build the tests. Running the end-to-end tests that use
instances from the builder requires an actual Navi card.
## Building and Testing
The builder test suite is organized into two main categories:

View File

@@ -85,21 +85,23 @@ The top-level signature contains global properties that apply to the entire conv
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
{ t.spatial_dim } -> std::convertible_to<unsigned int>; // 1, 2, or 3
{ t.data_type } -> std::convertible_to<DataType>; // Default data type
{ t.input } -> ConvTensorDescriptor;
{ t.weight } -> ConvTensorDescriptor;
{ t.output } -> ConvTensorDescriptor;
requires ConvolutionDirectionWellDefinedIfProvided<T>; // Optional direction
requires detail::DataTypeWellDefinedIfProvided<T>; // Optional default data type
requires detail::ElementwiseOpWellDefinedIfProvided<T>; // Optional default elementwise operation
};
```
**Properties:**
- **`spatial_dim`**: Dimensionality of the convolution (1D, 2D, or 3D)
- **`direction`**: Operation type (optional, defaults to FORWARD)
- **`direction`**: Operation type (Optional, defaults to FORWARD)
- `FORWARD`: Standard forward convolution
- `BACKWARD_DATA`: Gradient computation w.r.t. input
- `BACKWARD_WEIGHT`: Gradient computation w.r.t. weights
- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8)
- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8). (Optional, defaults to UNDEFINED_DATA_TYPE, may be overridden by tensors)
- **`operation`**: Default Operation (Optional, defaults to PASS_THROUGH, may be overridden by tensors)
- **`accumulation_data_type`**: Type used for internal accumulation
#### 2. Tensor Level
@@ -116,7 +118,7 @@ concept ConvTensorDescriptor = requires(T t) {
A tensor descriptor encapsulates:
- **Configuration**: Layout and data type information
- **Operation** (optional): Fused elementwise operations on this tensor
- **operation**: Fused elementwise operations on this tensor (Optional, default provided by ConvSignatureDescriptor)
#### 3. Tensor Configuration
@@ -126,7 +128,7 @@ Describes the memory layout and data types:
template <typename T>
concept TensorConfigDescriptor = requires(T t) {
{ t.layout } -> std::convertible_to<ConvLayout>;
{ t.data_type } -> std::convertible_to<DataType>; // Optional override
requires detail::DataTypeWellDefinedIfProvided<T>; // Override data type (Optional, default provided by ConvSignatureDescriptor)
};
```

View File

@@ -15,29 +15,31 @@ namespace ck_tile::builder {
/* Descriptors for individual elements of the algorithm description */
/********************************************************************/
// Common concept for size-related fields.
// Satisfied when T, with cv-qualifiers and references stripped, is an unsigned integral type.
// The stripping matters because this concept is used as a return-type constraint in compound
// requirements (e.g. `{ t.block_size } -> SizeType;`), where the constrained type is
// decltype((expr)) and is therefore a reference type for lvalue expressions.
// Unlike std::convertible_to<size_t>, signed integer and floating-point fields are rejected.
template <typename T>
concept SizeType = std::unsigned_integral<std::remove_cvref_t<T>>;
// Concept for thread block dimensions for a GEMM problem.
template <typename T>
concept ThreadBlockDescriptor = requires(T t) {
{ t.block_size } -> std::convertible_to<size_t>;
{ t.tile_size.m } -> std::convertible_to<size_t>;
{ t.tile_size.n } -> std::convertible_to<size_t>;
{ t.tile_size.k } -> std::convertible_to<size_t>;
{ t.block_size } -> SizeType;
{ t.tile_size.m } -> SizeType;
{ t.tile_size.n } -> SizeType;
{ t.tile_size.k } -> SizeType;
};
// Concept for parameters that describe a gridwise XDL GEMM problem.
template <typename T>
concept GridwiseXdlGemmDescriptor = requires(T t) {
{ t.ak1 } -> std::convertible_to<size_t>;
{ t.bk1 } -> std::convertible_to<size_t>;
{ t.m_per_xdl } -> std::convertible_to<size_t>;
{ t.n_per_xdl } -> std::convertible_to<size_t>;
{ t.m_xdl_per_wave } -> std::convertible_to<size_t>;
{ t.n_xdl_per_wave } -> std::convertible_to<size_t>;
{ t.m_per_xdl } -> SizeType;
{ t.n_per_xdl } -> SizeType;
{ t.m_xdl_per_wave } -> SizeType;
{ t.n_xdl_per_wave } -> SizeType;
};
// Concept for parameters that describe a block GEMM problem.
template <typename T>
concept BlockGemmDescriptor = requires(T t) {
concept BlockGemmPipelineDescriptor = requires(T t) {
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
};
@@ -45,37 +47,48 @@ concept BlockGemmDescriptor = requires(T t) {
// Concept for parameters that describe a gridwise WMMA GEMM problem.
template <typename T>
concept GridwiseWmmaGemmDescriptor = requires(T t) {
{ t.k1 } -> std::convertible_to<size_t>;
{ t.m_per_wmma } -> std::convertible_to<size_t>;
{ t.n_per_wmma } -> std::convertible_to<size_t>;
{ t.m_wmma_per_wave } -> std::convertible_to<size_t>;
{ t.n_wmma_per_wave } -> std::convertible_to<size_t>;
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ t.k1 } -> SizeType;
{ t.m_per_wmma } -> SizeType;
{ t.n_per_wmma } -> SizeType;
{ t.m_wmma_per_wave } -> SizeType;
{ t.n_wmma_per_wave } -> SizeType;
};
// Concept for vectorized data transfer for convolution input tensors.
template <typename T>
concept BlockTransferDescriptor = requires(T t) {
{ t.k0 } -> std::convertible_to<size_t>;
{ t.m_n } -> std::convertible_to<size_t>;
{ t.k1 } -> std::convertible_to<size_t>;
concept BlockTransferDescriptor3D = requires(T t) {
{ t.k0 } -> SizeType;
{ t.m_n } -> SizeType;
{ t.k1 } -> SizeType;
};
template <typename T>
concept BlockTransferDescriptor4D = requires(T t) {
{ t.k0 } -> SizeType;
{ t.m_n } -> SizeType;
{ t.k1 } -> SizeType;
{ t.k_batch_size } -> SizeType;
};
template <typename T, size_t ThreadClusterRank>
concept BlockTransferDescriptor = (ThreadClusterRank == 3 && BlockTransferDescriptor3D<T>) ||
(ThreadClusterRank == 4 && BlockTransferDescriptor4D<T>);
// Concept for thread cluster dimensions for GEMM output tensor.
template <typename T>
concept ThreadClusterDescriptor = requires(T t) {
{ t.m_block } -> std::convertible_to<size_t>;
{ t.m_wave_per_xdl } -> std::convertible_to<size_t>;
{ t.n_block } -> std::convertible_to<size_t>;
{ t.n_wave_per_xdl } -> std::convertible_to<size_t>;
{ t.m_block } -> SizeType;
{ t.m_wave_per_xdl } -> SizeType;
{ t.n_block } -> SizeType;
{ t.n_wave_per_xdl } -> SizeType;
};
// Concept for the LDS transfer for the convolution input tensors.
template <typename T>
concept LdsTransferDescriptor = requires(T t) {
{ t.src_vector_dim } -> std::convertible_to<size_t>;
{ t.src_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.lds_dst_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.src_vector_dim } -> SizeType;
{ t.src_scalar_per_vector } -> SizeType;
{ t.lds_dst_scalar_per_vector } -> SizeType;
{ t.is_direct_load } -> std::convertible_to<bool>;
{ t.lds_padding } -> std::convertible_to<bool>;
};
@@ -84,33 +97,35 @@ concept LdsTransferDescriptor = requires(T t) {
// LDS).
template <typename T>
concept EpilogueDescriptor = requires(T t) {
{ t.m_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
{ t.n_per_wave_per_shuffle } -> std::convertible_to<size_t>;
{ t.scalar_per_vector } -> std::convertible_to<size_t>;
{ t.m_xdl_per_wave_per_shuffle } -> SizeType;
{ t.n_per_wave_per_shuffle } -> SizeType;
{ t.scalar_per_vector } -> SizeType;
};
// Concept for the thread cluster access order.
// Satisfied when `t.order` is convertible to either std::array<size_t, 3> or
// std::array<size_t, 4>. The two requires-expressions are alternatives, so either
// rank is accepted; this concept does not check which rank a given user expects.
template <typename T>
concept AccessOrderDescriptor = requires(T t) {
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
} || requires(T t) {
{ t.order } -> std::convertible_to<std::array<size_t, 4>>;
};
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
// size is deduced from block gemm structure).
template <typename T>
concept TileThreadBlockDescriptor = requires(T t) {
{ t.tile_size.m } -> std::convertible_to<size_t>;
{ t.tile_size.n } -> std::convertible_to<size_t>;
{ t.tile_size.k } -> std::convertible_to<size_t>;
{ t.tile_size.m } -> SizeType;
{ t.tile_size.n } -> SizeType;
{ t.tile_size.k } -> SizeType;
};
// Concept for the per-tensor vector sizes (scalars per vector for A, B and C) of the
// data transfer for CK Tile.
template <typename T>
concept TileTransferDescriptor = requires(T t) {
{ t.a_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.b_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.c_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.a_scalar_per_vector } -> SizeType;
{ t.b_scalar_per_vector } -> SizeType;
{ t.c_scalar_per_vector } -> SizeType;
};
// Concept to check if struct specifies block GEMM (CK Tile).
@@ -159,30 +174,51 @@ concept SpecifiesTileThreadBlock = requires {
// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept SpecifiesGridwiseXdlGemm = requires {
{ T::gridwise_gemm } -> GridwiseXdlGemmDescriptor;
concept GridwiseFwdXdlGemmDescriptor = requires(T t) {
{ t.ak1 } -> SizeType;
{ t.bk1 } -> SizeType;
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
};
// Concept for gridwise XDL GEMM parameters described by a single k1 plus nested xdl_params.
template <typename T>
concept GridwiseBwdXdlGemmDescriptor = requires(T t) {
{ t.k1 } -> SizeType;
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
};
// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept SpecifiesGridwiseFwdXdlGemm = requires(T t) {
{ t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor;
};
// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept SpecifiesGridwiseBwdXdlGemm = requires(T t) {
{ t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor;
};
// Concept to check if a struct specifies gridwise WMMA GEMM info.
template <typename T>
concept SpecifiesGridwiseWmmaGemm = requires {
{ T::gridwise_gemm } -> GridwiseWmmaGemmDescriptor;
concept SpecifiesGridwiseWmmaGemm = requires(T t) {
{ t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor;
};
// Concept to check if a struct specifies convolution input and output block transfer info.
template <typename T>
template <typename T, size_t ThreadClusterRank = 3>
concept SpecifiesBlockTransfer = requires(T t) {
{ T::transfer.a.block_transfer } -> BlockTransferDescriptor;
{ T::transfer.b.block_transfer } -> BlockTransferDescriptor;
{ T::transfer.a.block_transfer } -> BlockTransferDescriptor<ThreadClusterRank>;
{ T::transfer.b.block_transfer } -> BlockTransferDescriptor<ThreadClusterRank>;
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
};
// Concept to check if a struct specifies convolution scalar per vector info for A, B and C.
template <typename T>
concept SpecifiesTileTransfer = requires(T t) {
{ T::transfer.a_scalar_per_vector } -> std::convertible_to<size_t>;
{ T::transfer.b_scalar_per_vector } -> std::convertible_to<size_t>;
{ T::transfer.c_scalar_per_vector } -> std::convertible_to<size_t>;
{ T::transfer.a_scalar_per_vector } -> SizeType;
{ T::transfer.b_scalar_per_vector } -> SizeType;
{ T::transfer.c_scalar_per_vector } -> SizeType;
};
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
@@ -210,8 +246,12 @@ concept SpecifiesSourceAccessOrder = requires(T t) {
// Concept to check if struct specifies block GEMM.
template <typename T>
concept SpecifiesBlockGemm = requires {
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
{ T::block_gemm_pipeline } -> BlockGemmPipelineDescriptor;
};
template <typename T>
concept SpecifiesGridwiseGemmPipeline = requires {
{ T::pipeline_version } -> std::convertible_to<PipelineVersion>;
};
// Concept to check if struct specifies block GEMM (CK Tile).
@@ -244,7 +284,12 @@ concept SpecifiesTileConvSpecialization = requires {
template <typename T>
concept SpecifiesFwdConvSpecialization = requires {
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
{ T::fwd_specialization } -> std::convertible_to<ConvSpecialization>;
};
template <typename T>
concept SpecifiesBwdWeightConvSpecialization = requires {
{ T::bwd_weight_specialization } -> std::convertible_to<ConvSpecialization>;
};
template <typename T>
@@ -254,12 +299,12 @@ concept SpecifiesGemmSpecialization = requires {
template <typename T>
concept SpecifiesNumPrefetchStages = requires {
{ T::num_gemm_k_prefetch_stages } -> std::convertible_to<size_t>;
{ T::num_gemm_k_prefetch_stages } -> SizeType;
};
template <typename T>
concept SpecifiesNumGroupsToMerge = requires {
{ T::num_groups_to_merge } -> std::convertible_to<size_t>;
{ T::num_conv_groups_to_merge } -> SizeType;
};
template <typename T>
@@ -267,12 +312,59 @@ concept SpecifiesLoopScheduler = requires {
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Satisfied when T does NOT expose a `specialization` member, i.e. `T::specialization` is not
// a valid expression. Note the leading `!`: this is the negation of a requires-expression.
template <typename T>
concept SpecifiesGenericInstance = !requires {
{ T::specialization };
};
// Satisfied when T provides both transpose-transfer limits
// (`max_transpose_transfer_src_scalar_per_vector` and
// `max_transpose_transfer_dst_scalar_per_vector`) as static members whose types satisfy SizeType.
template <typename T>
concept SpecifiesTransposeTransfer = requires {
{ T::max_transpose_transfer_src_scalar_per_vector } -> SizeType;
{ T::max_transpose_transfer_dst_scalar_per_vector } -> SizeType;
};
// Satisfied when T names both transpose-transfer members, regardless of their types.
// Presence check only; type validation is done by SpecifiesTransposeTransfer.
template <typename T>
concept HasTransposeTransfer = requires {
{ T::max_transpose_transfer_src_scalar_per_vector };
{ T::max_transpose_transfer_dst_scalar_per_vector };
};
// Treats the transpose-transfer members as optional: satisfied when T does not provide both
// members, or provides both with types satisfying SizeType.
// Note: a T that provides only one of the two members also satisfies this concept, because
// HasTransposeTransfer requires both members to be present.
template <typename T>
concept TransposeTransferWellDefinedIfProvided =
!HasTransposeTransfer<T> || SpecifiesTransposeTransfer<T>;
// Satisfied when T has a static `num_conv_groups_to_merge` member whose type satisfies SizeType.
template <typename T>
concept SpecifiesGemmBatchOptions = requires {
{ T::num_conv_groups_to_merge } -> SizeType;
};
/******************************************** */
/* Algorithm specialization concepts */
/******************************************** */
// Satisfied when `T::specialization` is convertible to ConvAlgorithmSpecialization and
// compares equal to ConvAlgorithmSpecialization::LARGE_TENSOR in a constant expression.
template <typename T>
concept SpecifiesLargeTensorSupport = requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR;
};
// Satisfied when `T::specialization` is convertible to ConvAlgorithmSpecialization and
// compares equal to ConvAlgorithmSpecialization::REFERENCE in a constant expression.
template <typename T>
concept SpecifiesReferenceAlgorithm = requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires T::specialization == ConvAlgorithmSpecialization::REFERENCE;
};
// Satisfied when `T::specialization` is convertible to ConvAlgorithmSpecialization and
// compares equal to ConvAlgorithmSpecialization::TWO_STAGE in a constant expression.
template <typename T>
concept SpecifiesTwoStageSupport = requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE;
};
// Satisfied when `T::specialization` is convertible to ConvAlgorithmSpecialization and
// compares equal to ConvAlgorithmSpecialization::MULTIPLE_D in a constant expression.
template <typename T>
concept SpecifiesMultipleDSupport = requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
requires T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D;
};
/******************************************** */
/* DL-specific descriptors and requirements */
/******************************************** */
@@ -280,11 +372,11 @@ concept SpecifiesLargeTensorSupport = requires {
// Concept for DL thread configuration
template <typename T>
concept DlThreadConfigDescriptor = requires(T t) {
{ t.k0_per_block } -> std::convertible_to<size_t>;
{ t.k1 } -> std::convertible_to<size_t>;
{ t.m1_per_thread } -> std::convertible_to<size_t>;
{ t.n1_per_thread } -> std::convertible_to<size_t>;
{ t.k_per_thread } -> std::convertible_to<size_t>;
{ t.k0_per_block } -> SizeType;
{ t.k1 } -> SizeType;
{ t.m1_per_thread } -> SizeType;
{ t.n1_per_thread } -> SizeType;
{ t.k_per_thread } -> SizeType;
};
// Concept for DL thread cluster
@@ -295,23 +387,29 @@ concept DlThreadClusterDescriptor = requires(T t) {
};
// Concept for DL block transfer
template <typename T>
template <typename T, size_t N>
concept DlBlockTransferDescriptor = requires(T t) {
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, N>>;
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, N>>;
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, N>>;
{ t.src_access_order } -> std::convertible_to<std::array<size_t, N>>;
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, N>>;
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, N>>;
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, N>>;
};
// Single-parameter form of DlBlockTransferDescriptor with N fixed to 4, so it can be used as a
// type constraint in compound requirements (`{ expr } -> DlBlockTransferDescriptor4D;`).
template <typename T>
concept DlBlockTransferDescriptor4D = DlBlockTransferDescriptor<T, 4>;
// Single-parameter form of DlBlockTransferDescriptor with N fixed to 5, so it can be used as a
// type constraint in compound requirements (`{ expr } -> DlBlockTransferDescriptor5D;`).
template <typename T>
concept DlBlockTransferDescriptor5D = DlBlockTransferDescriptor<T, 5>;
// Concept for DL epilogue
template <typename T>
concept DlEpilogueDescriptor = requires(T t) {
{ t.src_dst_access_order } -> std::convertible_to<std::array<size_t, 6>>;
{ t.src_dst_vector_dim } -> std::convertible_to<size_t>;
{ t.dst_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.src_dst_vector_dim } -> SizeType;
{ t.dst_scalar_per_vector } -> SizeType;
};
// Concept to check if algorithm specifies DL thread config
@@ -328,15 +426,21 @@ concept SpecifiesDlThreadCluster = requires {
// Concept to check if algorithm specifies DL block transfer
template <typename T>
concept SpecifiesDlBlockTransfer = requires {
{ T::transfer.a.block_transfer } -> DlBlockTransferDescriptor;
{ T::transfer.b.block_transfer } -> DlBlockTransferDescriptor;
concept SpecifiesDlFwdBlockTransfer = requires {
{ T::transfer.a } -> DlBlockTransferDescriptor4D;
{ T::transfer.b } -> DlBlockTransferDescriptor4D;
};
// Concept to check if algorithm specifies DL block transfer using rank-5 descriptors:
// both `T::transfer.a` and `T::transfer.b` must satisfy DlBlockTransferDescriptor5D.
template <typename T>
concept SpecifiesDlBwdBlockTransfer = requires {
{ T::transfer.a } -> DlBlockTransferDescriptor5D;
{ T::transfer.b } -> DlBlockTransferDescriptor5D;
};
// Concept to check if algorithm specifies DL C thread transfer
template <typename T>
concept SpecifiesDlEpilogue = requires {
{ T::transfer.c.epilogue } -> DlEpilogueDescriptor;
{ T::transfer.c } -> DlEpilogueDescriptor;
};
} // namespace ck_tile::builder

View File

@@ -29,10 +29,20 @@ concept OutputVectorTransferLimits = requires {
// Limits for access order. Must be a permutation of {0, 1, 2}.
template <auto Value>
concept AccessOrderLimits = requires {
concept AccessOrderLimits3D = requires {
requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) &&
(Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) &&
(Value[2] >= 0 && Value[2] < 3));
(Value[2] >= 0 && Value[2] < 3) && (Value.Size() == 3));
};
// Limits for access order. Must be a permutation of {0, 1, 2, 3}.
// Checked as: all six pairs of elements differ, each of the four elements lies in [0, 4),
// and the container reports exactly four elements.
// Value must support operator[] and a `Size()` member function; if either expression is
// ill-formed for the given Value, the concept is simply not satisfied.
template <auto Value>
concept AccessOrderLimits4D = requires {
requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[0] != Value[3]) &&
(Value[1] != Value[2]) && (Value[1] != Value[3]) && (Value[2] != Value[3]) &&
(Value[0] >= 0 && Value[0] < 4) && (Value[1] >= 0 && Value[1] < 4) &&
(Value[2] >= 0 && Value[2] < 4) && (Value[3] >= 0 && Value[3] < 4) &&
(Value.Size() == 4));
};
} // namespace ck_tile::builder

View File

@@ -80,6 +80,7 @@ concept ConvOutputLayout3D =
(L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) ||
(L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided);
namespace detail {
template <typename T>
concept HasDataType = requires(T t) {
{ t.data_type };
@@ -94,10 +95,11 @@ concept DataTypeWellDefinedIfProvided = requires(T t) {
};
};
} // namespace detail
template <typename T>
concept TensorConfigDescriptor = requires(T t) {
{ t.layout } -> std::convertible_to<TensorLayout>;
requires DataTypeWellDefinedIfProvided<T>;
requires detail::DataTypeWellDefinedIfProvided<T>;
};
template <typename T>
@@ -116,7 +118,6 @@ template <typename T, std::size_t N>
struct IsArrayOfTensorConfigDescriptors<std::array<T, N>> : std::true_type
{
};
} // namespace detail
template <typename T>
concept ConvertibleToArrayOfTensorConfigs =
@@ -128,11 +129,12 @@ concept AuxiliaryOperandConfigsWellDefinedIfProvided = requires(T t) {
{ t.auxiliary_operand_configs } -> ConvertibleToArrayOfTensorConfigs;
};
};
} // namespace detail
template <typename T>
concept TensorOperatorDescriptor = requires(T t) {
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
requires AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
requires detail::AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
};
template <typename T>
@@ -140,6 +142,8 @@ concept HasTensorOp = requires(T t) {
{ t.operation };
};
namespace detail {
template <typename T>
concept HasConvolutionDirection = requires(T t) {
{ t.direction };
@@ -159,11 +163,13 @@ concept ConvolutionDirectionWellDefinedIfProvided = requires(T t) {
};
};
} // namespace detail
// Concept for the convolution tensor
template <typename T>
concept ConvTensorDescriptor = requires(T t) {
{ t.config } -> TensorConfigDescriptor;
requires ElementwiseOpWellDefinedIfProvided<T>;
requires detail::ElementwiseOpWellDefinedIfProvided<T>;
};
template <typename T>
@@ -179,8 +185,9 @@ concept ConvSignatureDescriptor = requires(T t) {
{ t.input } -> ConvTensorDescriptor;
{ t.weight } -> ConvTensorDescriptor;
{ t.output } -> ConvTensorDescriptor;
requires ConvolutionDirectionWellDefinedIfProvided<T>;
requires DataTypeWellDefinedIfProvided<T>;
requires detail::ConvolutionDirectionWellDefinedIfProvided<T>;
requires detail::DataTypeWellDefinedIfProvided<T>;
requires detail::ElementwiseOpWellDefinedIfProvided<T>;
};
// Concept to validate a convolution signature's values.
@@ -221,4 +228,13 @@ concept ValidConvWeightLayoutForSpatialDim =
(SpatialDim == 1 && ConvWeightLayout1D<L>) || (SpatialDim == 2 && ConvWeightLayout2D<L>) ||
(SpatialDim == 3 && ConvWeightLayout3D<L>);
// Constraint for 3D conv signature.
// Satisfied when the signature value has spatial_dim == 3 and its input, output and weight
// tensor layouts each satisfy the corresponding 3D layout concept.
template <auto Sig>
concept Is3D = requires {
requires Sig.spatial_dim == 3;
requires ConvInputLayout3D<Sig.input.config.layout>;
requires ConvOutputLayout3D<Sig.output.config.layout>;
requires ConvWeightLayout3D<Sig.weight.config.layout>;
};
} // namespace ck_tile::builder

View File

@@ -0,0 +1,128 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
// Algorithm-level concepts for the convolution factories. Every concept in this namespace is a
// conjunction of the building-block concepts from conv_algorithm_concepts.hpp.
namespace ck_tile::builder::factory {
// Base algorithm concepts
// Block transfer (for the given thread cluster rank), LDS transfer, thread cluster access
// order and source access order must all be specified. The rank defaults to 3.
template <typename T, size_t ThreadClusterRank = 3>
concept TileTransferParameters =
SpecifiesBlockTransfer<T, ThreadClusterRank> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T>;
// Single-parameter form of TileTransferParameters with thread cluster rank 3.
template <typename T>
concept SpecifiesTileTransferParameters3D = TileTransferParameters<T, 3>;
// Single-parameter form of TileTransferParameters with thread cluster rank 4.
template <typename T>
concept SpecifiesTileTransferParameters4D = TileTransferParameters<T, 4>;
// Requirements shared by FwdXdlAlgorithm and LargeTensorAlgorithm.
template <typename T>
concept FwdXdlAlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseFwdXdlGemm<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
// Requirements shared by BwdXdlAlgorithm and BwdMultiDXdlAlgorithm.
// This is the only base concept here that uses rank-4 tile transfer parameters.
template <typename T>
concept BwdXdlAlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters4D<T> &&
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T>;
// Requirements shared by BwdXdlV3Algorithm and BwdTwoStageXdlAlgorithm. Compared with
// BwdXdlAlgorithmBase it uses rank-3 tile transfer parameters and additionally requires
// SpecifiesBlockGemm.
template <typename T>
concept BwdXdlV3AlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
SpecifiesBlockGemm<T>;
// Requirements that BwdWmmaAlgorithm builds on.
template <typename T>
concept BwdWmmaAlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseWmmaGemm<T> && SpecifiesBwdWeightConvSpecialization<T>;
// Requirements shared by the BWD weight WMMA V3 concepts; BwdWmmaAlgorithmBase plus
// SpecifiesBlockGemm.
template <typename T>
concept BwdWmmaV3AlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseWmmaGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
SpecifiesBlockGemm<T>;
// Reference algorithm concept
template <typename T>
concept ReferenceAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesReferenceAlgorithm<T>;
// Tile-based algorithm concept
template <typename T>
concept TileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> &&
SpecifiesTileTransfer<T> && SpecifiesTileConvSpecialization<T> &&
SpecifiesTileBlockGemm<T> && SpecifiesTileOptimizations<T>;
// FWD XDL algorithm concepts
// Generic (no `specialization` member) forward XDL algorithm.
template <typename T>
concept FwdXdlAlgorithm = FwdXdlAlgorithmBase<T> && SpecifiesGenericInstance<T>;
// Forward XDL algorithm whose specialization is LARGE_TENSOR.
template <typename T>
concept LargeTensorAlgorithm = FwdXdlAlgorithmBase<T> && SpecifiesLargeTensorSupport<T>;
// Forward XDL V3 algorithm. Not built on FwdXdlAlgorithmBase: it requires SpecifiesBlockGemm
// instead of prefetch stages, groups-to-merge and loop scheduler, and it places no
// constraint on `specialization`.
template <typename T>
concept FwdXdlV3Algorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseFwdXdlGemm<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
// FWD WMMA algorithm concepts
template <typename T>
concept FwdWmmaAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseWmmaGemm<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T> &&
SpecifiesGridwiseGemmPipeline<T>;
// FWD DL algorithms
template <typename T>
concept FwdDlAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
SpecifiesDlFwdBlockTransfer<T> && SpecifiesDlEpilogue<T>;
// BWD weight XDL algorithm concepts
// Generic BWD weight XDL algorithm; transpose-transfer limits are mandatory here.
template <typename T>
concept BwdXdlAlgorithm =
BwdXdlAlgorithmBase<T> && SpecifiesTransposeTransfer<T> && SpecifiesGenericInstance<T>;
// BWD weight XDL algorithm whose specialization is MULTIPLE_D.
template <typename T>
concept BwdMultiDXdlAlgorithm = BwdXdlAlgorithmBase<T> && SpecifiesMultipleDSupport<T>;
// Generic BWD weight XDL V3 algorithm.
template <typename T>
concept BwdXdlV3Algorithm = BwdXdlV3AlgorithmBase<T> && SpecifiesGenericInstance<T>;
// BWD weight XDL V3 algorithm whose specialization is TWO_STAGE; also needs transpose-transfer
// limits and GEMM batch options.
template <typename T>
concept BwdTwoStageXdlAlgorithm = BwdXdlV3AlgorithmBase<T> && SpecifiesTransposeTransfer<T> &&
SpecifiesGemmBatchOptions<T> && SpecifiesTwoStageSupport<T>;
// BWD weight WMMA algorithm concepts
// Generic BWD weight WMMA algorithm.
template <typename T>
concept BwdWmmaAlgorithm =
BwdWmmaAlgorithmBase<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T> &&
SpecifiesGridwiseGemmPipeline<T> && SpecifiesGenericInstance<T>;
// BWD weight WMMA V3 algorithm whose specialization is MULTIPLE_D.
template <typename T>
concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase<T> && SpecifiesMultipleDSupport<T>;
// Generic BWD weight WMMA V3 algorithm; transpose-transfer limits are mandatory here.
template <typename T>
concept BwdWmmaV3Algorithm =
BwdWmmaV3AlgorithmBase<T> && SpecifiesTransposeTransfer<T> && SpecifiesGenericInstance<T>;
// BWD weight WMMA V3 algorithm whose specialization is TWO_STAGE; also needs transpose-transfer
// limits and GEMM batch options.
template <typename T>
concept BwdTwoStageWmmaV3Algorithm = BwdWmmaV3AlgorithmBase<T> && SpecifiesTransposeTransfer<T> &&
SpecifiesGemmBatchOptions<T> && SpecifiesTwoStageSupport<T>;
// BWD weight DL algorithms
template <typename T>
concept BwdDlAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
SpecifiesBwdWeightConvSpecialization<T> && SpecifiesDlThreadConfig<T> &&
SpecifiesDlThreadCluster<T> && SpecifiesDlBwdBlockTransfer<T> && SpecifiesDlEpilogue<T>;
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,131 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeight_Dl instance
// of a grouped bwd weight convolution kernel.
//
// Template parameters:
//   SIGNATURE - compile-time convolution signature; must describe a backward-weight
//               convolution (enforced by the requires-clause).
//   ALGORITHM - compile-time algorithm descriptor; this struct reads its `thread_config`,
//               `thread_cluster` and `transfer.{a,b,c}` members.
//   VERSION   - not referenced inside this struct.
//
// The struct only translates descriptor fields into the positional template arguments of the
// device kernel; the resulting kernel type is exposed as `Instance`.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct ConvBwdWeightDlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
// Layout, data type and elementwise-op bundles derived from the signature.
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
// DL-specific parameters from algorithm descriptor
// Each field is converted to ck::index_t, the type used for the kernel's integral arguments
// below.
static constexpr auto DL_THREAD_CFG = ALGORITHM.thread_config;
static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block;
static constexpr ck::index_t K1 = DL_THREAD_CFG.k1;
static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread;
static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread;
static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread;
// Thread cluster from descriptor
static constexpr auto DL_CLUSTER = ALGORITHM.thread_cluster;
using M1N1ThreadClusterM1Xs = to_sequence_v<DL_CLUSTER.m1_xs>;
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
// NOTE(review): the alias names list four dimensions; the rank of the arrays actually
// supplied by ALGORITHM.transfer.a is not constrained here - confirm against the kernel.
static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a;
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.thread_cluster_lengths>;
using ABlockTransferThreadClusterArrangeOrder =
to_sequence_v<DL_A_TRANSFER.thread_cluster_arrange_order>;
using ABlockTransferSrcAccessOrder = to_sequence_v<DL_A_TRANSFER.src_access_order>;
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_lengths>;
using ABlockTransferSrcVectorTensorContiguousDimOrder =
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_contiguous_dim_order>;
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
// Mirrors the A block transfer aliases above, reading from ALGORITHM.transfer.b.
static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b;
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.thread_cluster_lengths>;
using BBlockTransferThreadClusterArrangeOrder =
to_sequence_v<DL_B_TRANSFER.thread_cluster_arrange_order>;
using BBlockTransferSrcAccessOrder = to_sequence_v<DL_B_TRANSFER.src_access_order>;
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_lengths>;
using BBlockTransferSrcVectorTensorContiguousDimOrder =
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_contiguous_dim_order>;
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
// C Thread Transfer from descriptor
static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c;
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
DL_C_TRANSFER.dst_scalar_per_vector;
// The DL backward weight convolution kernel class instance.
// Arguments are positional, so their order must match the kernel's template parameter list.
// Only the M and N tile sizes of BLOCK are passed; the K extent is given by K0PerBlock and K1.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
};
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,110 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 instance
// of a grouped bwd weight convolution kernel.
//
// Template parameters:
//   SIGNATURE - convolution signature descriptor; must satisfy
//               ConvDirectionIsBackwardWeight and Is3D.
//   ALGORITHM - algorithm descriptor; its gridwise_gemm and transfer.a/transfer.b
//               members are read directly, the rest goes through internal::Set* helpers.
//   VERSION   - string literal tag; not referenced inside this struct.
//
// The assembled device operation type is exposed as the nested alias `Instance`.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE> && Is3D<SIGNATURE>
struct ConvBwdWeightMultiDWmmaV3Factory
{
// Signature-derived pieces: spatial dimension, tensor layouts, data types
// and elementwise operators.
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
// Algorithm-derived tuning parameters, all evaluated at compile time.
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
// Taken verbatim from the descriptor; supplies k1 and the WMMA tile parameters below.
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid A thread cluster access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid B thread cluster access order");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
"Invalid A source access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
// The backward weight convolution kernel class instance.
// Template arguments are positional: layouts, data types, elementwise ops,
// specialization, block/tile sizes, A transfer, B transfer, C transfer,
// block GEMM scheduling, then the compute types (Out first, then In).
// NOTE(review): the C transfer fields keep their *_xdl_* names although this is a
// WMMA kernel; presumably they feed the per-shuffle repeat parameters - confirm
// against the device op's template parameter list.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Layouts::DsLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Types::DsDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
GRIDWISE_GEMM.m_per_wmma,
GRIDWISE_GEMM.n_per_wmma,
GRIDWISE_GEMM.m_wmma_per_wave,
GRIDWISE_GEMM.n_wmma_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
BLOCK_GEMM.scheduler,
BLOCK_GEMM.pipeline_version,
typename Types::OutComputeType,
typename Types::InComputeType>;
};
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,103 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance
// of a grouped bwd weight convolution kernel.
//
// Template parameters:
//   SIGNATURE - convolution signature descriptor; must satisfy
//               ConvDirectionIsBackwardWeight.
//   ALGORITHM - algorithm descriptor; its gridwise_gemm (including xdl_params) and
//               transfer.a/transfer.b members are read directly, the rest goes
//               through internal::Set* helpers.
//   VERSION   - string literal tag; not referenced inside this struct.
//
// The assembled device operation type is exposed as the nested alias `Instance`.
template <ConvSignatureDescriptor auto SIGNATURE,
          ConvAlgorithmDescriptor auto ALGORITHM,
          StringLiteral VERSION>
    requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct ConvBwdWeightMultiDXdlFactory
{
    // Signature-derived pieces: spatial dimension, tensor layouts, data types
    // and elementwise operators.
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
    using Layouts                       = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
    using Types                         = internal::ConvTensorDataTypes<SIGNATURE>;
    using Ops                           = internal::ConvElementwiseOps<SIGNATURE>;
    using AlgorithmType                 = decltype(ALGORITHM);

    // Algorithm-derived tuning parameters, all evaluated at compile time.
    static constexpr auto BWD_CONV_SPECIALIZATION =
        internal::SetBwdWeightConvSpecialization<ALGORITHM>();
    static constexpr auto BLOCK         = internal::SetThreadBlockInfo<ALGORITHM>();
    static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
    static constexpr auto XDL_PARAMS    = GRIDWISE_GEMM.xdl_params;
    static constexpr auto A_BLOCK_TRANSFER =
        internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
    static constexpr auto B_BLOCK_TRANSFER =
        internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
    static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();

    // Check limits for the algorithm parameters.
    // Each assertion carries a message so a rejected configuration names the
    // offending parameter group. Access orders are validated with the 4D limits.
    // TODO: Add more limits checks as needed.
    static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
    static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
    static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
    static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.thread_cluster_order>,
                  "Invalid A thread cluster access order");
    static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.thread_cluster_order>,
                  "Invalid B thread cluster access order");
    static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>,
                  "Invalid A source access order");
    static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>,
                  "Invalid B source access order");

    // The backward weight convolution kernel class instance.
    // Template arguments are positional: layouts, data types, elementwise ops,
    // specialization, block/tile sizes, A transfer, B transfer, C transfer,
    // then the compute types (Out first, then In).
    using Instance =
        ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle<
            SPATIAL_DIM,
            typename Layouts::InLayout,
            typename Layouts::WeiLayout,
            typename Layouts::OutLayout,
            typename Layouts::DsLayout,
            typename Types::InDataType,
            typename Types::WeiDataType,
            typename Types::OutDataType,
            typename Types::AccDataType,
            typename Types::DsDataType,
            typename Ops::InElementwiseOp,
            typename Ops::WeiElementwiseOp,
            typename Ops::OutElementwiseOp,
            BWD_CONV_SPECIALIZATION,
            BLOCK.block_size,
            BLOCK.per_block.m,
            BLOCK.per_block.n,
            BLOCK.per_block.k,
            GRIDWISE_GEMM.k1,
            XDL_PARAMS.m_per_xdl,
            XDL_PARAMS.n_per_xdl,
            XDL_PARAMS.m_xdl_per_wave,
            XDL_PARAMS.n_xdl_per_wave,
            to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
            to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
            to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
            A_BLOCK_TRANSFER.src_vector_dim,
            A_BLOCK_TRANSFER.src_scalar_per_vector,
            A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
            A_BLOCK_TRANSFER.lds_padding,
            to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
            to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
            to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
            B_BLOCK_TRANSFER.src_vector_dim,
            B_BLOCK_TRANSFER.src_scalar_per_vector,
            B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
            B_BLOCK_TRANSFER.lds_padding,
            C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
            C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
            to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
            C_BLOCK_TRANSFER.scalar_per_vector,
            typename Types::OutComputeType,
            typename Types::InComputeType>;
};
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,111 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 instance
// of a grouped bwd weight convolution kernel.
//
// Template parameters:
//   SIGNATURE - convolution signature descriptor; must satisfy
//               ConvDirectionIsBackwardWeight.
//   ALGORITHM - algorithm descriptor; its gridwise_gemm, transfer.a/transfer.b,
//               num_conv_groups_to_merge and max_transpose_transfer_* members are
//               read directly, the rest goes through internal::Set* helpers.
//   VERSION   - string literal tag; not referenced inside this struct.
//
// The assembled device operation type is exposed as the nested alias `Instance`.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct ConvBwdWeightTwoStageWmmaV3Factory
{
// Signature-derived pieces: spatial dimension, tensor layouts, data types
// and elementwise operators.
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
// Algorithm-derived tuning parameters, all evaluated at compile time.
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
// Taken verbatim from the descriptor; supplies k1 and the WMMA tile parameters below.
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid A thread cluster access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid B thread cluster access order");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
"Invalid A source access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
// The backward weight convolution kernel class instance.
// Template arguments are positional: layouts, data types, elementwise ops,
// specialization, block/tile sizes, A transfer, B transfer, C transfer,
// block GEMM scheduling, number of conv groups to merge, the compute types
// (Out first, then In), then the transpose-transfer vector width caps.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
GRIDWISE_GEMM.m_per_wmma,
GRIDWISE_GEMM.n_per_wmma,
GRIDWISE_GEMM.m_wmma_per_wave,
GRIDWISE_GEMM.n_wmma_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
BLOCK_GEMM.scheduler,
BLOCK_GEMM.pipeline_version,
ALGORITHM.num_conv_groups_to_merge,
typename Types::OutComputeType,
typename Types::InComputeType,
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
};
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,111 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle instance
// of a grouped bwd weight convolution kernel.
//
// Template parameters:
//   SIGNATURE - convolution signature descriptor; must satisfy
//               ConvDirectionIsBackwardWeight.
//   ALGORITHM - algorithm descriptor; its gridwise_gemm (including xdl_params),
//               transfer.a/transfer.b, num_conv_groups_to_merge and
//               max_transpose_transfer_* members are read directly, the rest goes
//               through internal::Set* helpers.
//   VERSION   - string literal tag; not referenced inside this struct.
//
// The assembled device operation type is exposed as the nested alias `Instance`.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct ConvBwdWeightTwoStageXdlFactory
{
// Signature-derived pieces: spatial dimension, tensor layouts, data types
// and elementwise operators.
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
// Algorithm-derived tuning parameters, all evaluated at compile time.
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
// Taken verbatim from the descriptor; k1 comes from GRIDWISE_GEMM, the XDL tile
// parameters from its xdl_params member.
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid A thread cluster access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid B thread cluster access order");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
"Invalid A source access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
// The backward weight convolution kernel class instance.
// Template arguments are positional: layouts, data types, elementwise ops,
// specialization, block/tile sizes, A transfer, B transfer, C transfer,
// block GEMM scheduling, number of conv groups to merge, the compute types
// (Out first, then In), then the transpose-transfer vector width caps.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
BLOCK_GEMM.scheduler,
BLOCK_GEMM.pipeline_version,
ALGORITHM.num_conv_groups_to_merge,
typename Types::OutComputeType,
typename Types::InComputeType,
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
};
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,109 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeight_Wmma_CShuffle instance
// of a grouped bwd weight convolution kernel.
//
// Template parameters:
//   SIGNATURE - convolution signature descriptor; must satisfy
//               ConvDirectionIsBackwardWeight and Is3D.
//   ALGORITHM - algorithm descriptor; its gridwise_gemm, transfer.a/transfer.b and
//               num_gemm_k_prefetch_stages members are read directly, the rest goes
//               through internal::Set* helpers.
//   VERSION   - string literal tag; not referenced inside this struct.
//
// The assembled device operation type is exposed as the nested alias `Instance`.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE> && Is3D<SIGNATURE>
struct ConvBwdWeightWmmaFactory
{
// Signature-derived pieces: spatial dimension, tensor layouts, data types
// and elementwise operators.
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
// Algorithm-derived tuning parameters, all evaluated at compile time.
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
// Taken verbatim from the descriptor; supplies k1 and the WMMA tile parameters below.
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
// This factory passes a gridwise pipeline version and a loop scheduler
// to the device op; it does not use a BLOCK_GEMM configuration.
static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION =
internal::SetGridwiseGemmPipelineVersion<ALGORITHM>();
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid A thread cluster access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid B thread cluster access order");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
"Invalid A source access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
// The backward weight convolution kernel class instance.
// Template arguments are positional: layouts, data types, elementwise ops,
// specialization, block/tile sizes, A transfer, B transfer, C transfer,
// then prefetch stage count, loop scheduler and pipeline version.
// No compute types are passed to this device op.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
GRIDWISE_GEMM.m_per_wmma,
GRIDWISE_GEMM.n_per_wmma,
GRIDWISE_GEMM.m_wmma_per_wave,
GRIDWISE_GEMM.n_wmma_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
ALGORITHM.num_gemm_k_prefetch_stages,
LOOP_SCHEDULER,
GRIDWISE_GEMM_PIPELINE_VERSION>;
};
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,109 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 instance
// of a grouped bwd weight convolution kernel.
//
// Template parameters:
//   SIGNATURE - convolution signature descriptor; must satisfy
//               ConvDirectionIsBackwardWeight.
//   ALGORITHM - algorithm descriptor; its gridwise_gemm, transfer.a/transfer.b and
//               max_transpose_transfer_* members are read directly, the rest goes
//               through internal::Set* helpers.
//   VERSION   - string literal tag; not referenced inside this struct.
//
// The assembled device operation type is exposed as the nested alias `Instance`.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct ConvBwdWeightWmmaV3Factory
{
// Signature-derived pieces: spatial dimension, tensor layouts, data types
// and elementwise operators.
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
// Algorithm-derived tuning parameters, all evaluated at compile time.
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
// Taken verbatim from the descriptor; supplies k1 and the WMMA tile parameters below.
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid A thread cluster access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid B thread cluster access order");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
"Invalid A source access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
// The backward weight convolution kernel class instance.
// Template arguments are positional: layouts, data types, elementwise ops,
// specialization, block/tile sizes, A transfer, B transfer, C transfer,
// block GEMM scheduling, the compute types (Out first, then In), then the
// transpose-transfer vector width caps.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
GRIDWISE_GEMM.m_per_wmma,
GRIDWISE_GEMM.n_per_wmma,
GRIDWISE_GEMM.m_wmma_per_wave,
GRIDWISE_GEMM.n_wmma_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
BLOCK_GEMM.scheduler,
BLOCK_GEMM.pipeline_version,
typename Types::OutComputeType,
typename Types::InComputeType,
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
};
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,103 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeight_Xdl_CShuffle instance
// of a grouped bwd weight convolution kernel.
//
// Template parameters:
//   SIGNATURE - convolution signature descriptor; must satisfy
//               ConvDirectionIsBackwardWeight.
//   ALGORITHM - algorithm descriptor; its gridwise_gemm (including xdl_params),
//               transfer.a/transfer.b and max_transpose_transfer_* members are read
//               directly, the rest goes through internal::Set* helpers.
//   VERSION   - string literal tag; not referenced inside this struct.
//
// The assembled device operation type is exposed as the nested alias `Instance`.
template <ConvSignatureDescriptor auto SIGNATURE,
          ConvAlgorithmDescriptor auto ALGORITHM,
          StringLiteral VERSION>
    requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct ConvBwdWeightXdlFactory
{
    // Signature-derived pieces: spatial dimension, tensor layouts, data types
    // and elementwise operators.
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
    using Layouts                       = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
    using Types                         = internal::ConvTensorDataTypes<SIGNATURE>;
    using Ops                           = internal::ConvElementwiseOps<SIGNATURE>;
    using AlgorithmType                 = decltype(ALGORITHM);

    // Algorithm-derived tuning parameters, all evaluated at compile time.
    static constexpr auto BWD_CONV_SPECIALIZATION =
        internal::SetBwdWeightConvSpecialization<ALGORITHM>();
    static constexpr auto BLOCK         = internal::SetThreadBlockInfo<ALGORITHM>();
    static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
    static constexpr auto XDL_PARAMS    = GRIDWISE_GEMM.xdl_params;
    static constexpr auto A_BLOCK_TRANSFER =
        internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
    static constexpr auto B_BLOCK_TRANSFER =
        internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
    static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();

    // Check limits for the algorithm parameters.
    // Each assertion carries a message so a rejected configuration names the
    // offending parameter group. Access orders are validated with the 4D limits.
    // TODO: Add more limits checks as needed.
    static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
    static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
    static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
    static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.thread_cluster_order>,
                  "Invalid A thread cluster access order");
    static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.thread_cluster_order>,
                  "Invalid B thread cluster access order");
    static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>,
                  "Invalid A source access order");
    static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>,
                  "Invalid B source access order");

    // The backward weight convolution kernel class instance.
    // Template arguments are positional: layouts, data types, elementwise ops,
    // specialization, block/tile sizes, A transfer, B transfer, C transfer,
    // the compute types (Out first, then In), then the transpose-transfer
    // vector width caps.
    using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
        SPATIAL_DIM,
        typename Layouts::InLayout,
        typename Layouts::WeiLayout,
        typename Layouts::OutLayout,
        typename Types::InDataType,
        typename Types::WeiDataType,
        typename Types::OutDataType,
        typename Types::AccDataType,
        typename Ops::InElementwiseOp,
        typename Ops::WeiElementwiseOp,
        typename Ops::OutElementwiseOp,
        BWD_CONV_SPECIALIZATION,
        BLOCK.block_size,
        BLOCK.per_block.m,
        BLOCK.per_block.n,
        BLOCK.per_block.k,
        GRIDWISE_GEMM.k1,
        XDL_PARAMS.m_per_xdl,
        XDL_PARAMS.n_per_xdl,
        XDL_PARAMS.m_xdl_per_wave,
        XDL_PARAMS.n_xdl_per_wave,
        to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
        to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
        to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
        A_BLOCK_TRANSFER.src_vector_dim,
        A_BLOCK_TRANSFER.src_scalar_per_vector,
        A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
        A_BLOCK_TRANSFER.lds_padding,
        to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
        to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
        to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
        B_BLOCK_TRANSFER.src_vector_dim,
        B_BLOCK_TRANSFER.src_scalar_per_vector,
        B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
        B_BLOCK_TRANSFER.lds_padding,
        C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
        C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
        to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
        C_BLOCK_TRANSFER.scalar_per_vector,
        typename Types::OutComputeType,
        typename Types::InComputeType,
        ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
        ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
};
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,108 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {
// Factory for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 instance
// of a grouped bwd weight convolution kernel.
//
// Template parameters:
//   SIGNATURE - convolution signature descriptor; must describe a backward-weight
//               convolution (enforced by the requires-clause below).
//   ALGORITHM - algorithm descriptor supplying thread-block, gridwise GEMM, block-transfer
//               and block-GEMM parameters.
//   VERSION   - string literal; it is not referenced inside this struct.
//
// The struct only translates descriptor values into template arguments; the resulting
// kernel type is exposed as the nested alias `Instance`.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct ConvBwdWeightXdlV3Factory
{
// Signature-derived layouts, data types and elementwise operators.
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
// Algorithm-derived tuning parameters, converted by the internal::Set* helpers.
static constexpr auto BWD_CONV_SPECIALIZATION =
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
// A and B use the backward block-transfer helper; the C transfer also needs SIGNATURE.
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
// The 3D access-order checks apply to both the thread-cluster order and the source
// access order of A and B.
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid A thread cluster access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
"Invalid B thread cluster access order");
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
"Invalid A source access order");
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
// The backward weight convolution kernel class instance.
// Argument order below must match the template parameter list of
// DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 exactly.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<
SPATIAL_DIM,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
BWD_CONV_SPECIALIZATION,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
// A single K1 value is passed for the gridwise GEMM.
GRIDWISE_GEMM.k1,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
// A block transfer parameters.
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
// B block transfer parameters.
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
// C shuffle / output transfer parameters.
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
BLOCK_GEMM.scheduler,
BLOCK_GEMM.pipeline_version,
// NOTE(review): the output compute type is passed before the input compute type;
// presumably these fill the kernel's A/B compute-type slots - confirm against the
// device op's template parameter list.
typename Types::OutComputeType,
typename Types::InComputeType>;
};
} // namespace ck_tile::builder::factory

View File

@@ -57,6 +57,9 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/types.hpp"
// Compile time diagnostics
#include "ck_tile/builder/factory/conv_algorithms.hpp"
// Include all factory implementations
#include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp"
@@ -65,6 +68,15 @@
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
#include "ck_tile/builder/factory/reference_factory.hpp"
#include "ck_tile/builder/factory/conv_tile_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp"
#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp"
namespace ck_tile::builder::factory {
@@ -87,56 +99,6 @@ namespace ck_tile::builder::factory {
//
// TODO: Make this dispatch logic much more robust and clear for users.
// Reference algorithm (simplest implementation for validation).
// Satisfied when T is a ConvAlgorithmDescriptor, T::specialization is convertible to
// ConvAlgorithmSpecialization, and its value equals ConvAlgorithmSpecialization::REFERENCE.
template <typename T>
concept IsReferenceAlgorithm = ConvAlgorithmDescriptor<T> && requires {
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
// Nested requirement: checks the value, not just the type.
requires T::specialization == ConvAlgorithmSpecialization::REFERENCE;
};
// CK Tile kernel.
// Satisfied when T is a ConvAlgorithmDescriptor that also satisfies every SpecifiesTile*
// concept listed below (thread block, transfer, conv specialization, block GEMM,
// optimizations).
template <typename T>
concept IsTileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> &&
SpecifiesTileTransfer<T> && SpecifiesTileConvSpecialization<T> &&
SpecifiesTileBlockGemm<T> && SpecifiesTileOptimizations<T>;
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline).
// All listed Specifies* concepts must hold. Compared with IsXdlAlgorithm, this concept
// requires SpecifiesBlockGemm instead of SpecifiesNumPrefetchStages,
// SpecifiesNumGroupsToMerge and SpecifiesLoopScheduler.
template <typename T>
concept IsXdlV3Algorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply).
// All listed Specifies* concepts must hold, including the forward conv specialization,
// the number of prefetch stages, the number of groups to merge and the loop scheduler.
template <typename T>
concept IsXdlAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
// WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions).
// Same requirement list as IsXdlAlgorithm except that the gridwise GEMM must satisfy
// SpecifiesGridwiseWmmaGemm (not the XDL variant) and SpecifiesNumGroupsToMerge is not
// required.
template <typename T>
concept IsWmmaAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts.
// Requires the common thread-block, forward conv and GEMM specialization concepts plus
// the DL-specific ones: thread config, thread cluster, block transfer and epilogue.
template <typename T>
concept IsDlAlgorithm =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConvSpecialization<T> &&
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
// XDL-based kernel with large tensor support.
// The XDL requirements are checked on the type of T::base_algorithm rather than on T
// itself; T only has to satisfy SpecifiesLargeTensorSupport in addition.
template <typename T>
concept IsLargeTensorAlgorithm =
IsXdlAlgorithm<decltype(T::base_algorithm)> && SpecifiesLargeTensorSupport<T>;
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
@@ -145,35 +107,35 @@ constexpr auto make_conv_instance()
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
// Reference algorithm supports all directions
if constexpr(IsReferenceAlgorithm<AlgoType>)
if constexpr(ReferenceAlgorithm<AlgoType>)
{
return typename ReferenceFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
// CK Tile supports common factory for each direction
else if constexpr(IsTileAlgorithm<AlgoType>)
else if constexpr(TileAlgorithm<AlgoType>)
{
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
// Forward direction (supports most algorithm variants)
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
if constexpr(IsXdlV3Algorithm<AlgoType>)
if constexpr(FwdXdlV3Algorithm<AlgoType>)
{
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsXdlAlgorithm<AlgoType>)
else if constexpr(FwdXdlAlgorithm<AlgoType>)
{
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsWmmaAlgorithm<AlgoType>)
else if constexpr(FwdWmmaAlgorithm<AlgoType>)
{
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsDlAlgorithm<AlgoType>)
else if constexpr(FwdDlAlgorithm<AlgoType>)
{
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(IsLargeTensorAlgorithm<AlgoType>)
else if constexpr(LargeTensorAlgorithm<AlgoType>)
{
return typename ConvFwdLargeTensorFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
@@ -197,10 +159,55 @@ constexpr auto make_conv_instance()
// Backward weight direction (will expand with more algorithms in the future)
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
{
static_assert(false,
"Backward weight convolution: Only reference and tile algorithms "
"supported currently. "
"Optimized kernels (XDL, WMMA, etc.) not yet implemented.");
if constexpr(BwdXdlAlgorithm<AlgoType>)
{
return typename ConvBwdWeightXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdXdlV3Algorithm<AlgoType>)
{
return typename ConvBwdWeightXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdTwoStageXdlAlgorithm<AlgoType>)
{
return
typename ConvBwdWeightTwoStageXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdDlAlgorithm<AlgoType>)
{
return typename ConvBwdWeightDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdMultiDXdlAlgorithm<AlgoType>)
{
return
typename ConvBwdWeightMultiDXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdWmmaV3Algorithm<AlgoType>)
{
return typename ConvBwdWeightWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdTwoStageWmmaV3Algorithm<AlgoType>)
{
return typename ConvBwdWeightTwoStageWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::
Instance{};
}
else if constexpr(BwdWmmaAlgorithm<AlgoType>)
{
return typename ConvBwdWeightWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(BwdMultiDWmmaV3Algorithm<AlgoType>)
{
return typename ConvBwdWeightMultiDWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::
Instance{};
}
else
{
static_assert(
false,
"No suitable backward weight convolution kernel factory found for the provided "
"ALGORITHM. The ALGORITHM must satisfy requirements for one of: Reference, Tile, "
"XDL, XDL V3, Two-Stage XDL, DL, Multi-D XDL, WMMA V3, Two-Stage "
"WMMA V3, WMMA, or Multi-D WMMA V3 variant.");
}
}
else
{

View File

@@ -24,10 +24,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdDlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
@@ -48,7 +48,7 @@ struct ConvFwdDlFactory
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a.block_transfer;
static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a;
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
@@ -64,7 +64,7 @@ struct ConvFwdDlFactory
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b.block_transfer;
static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b;
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
@@ -80,7 +80,7 @@ struct ConvFwdDlFactory
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
// C Thread Transfer from descriptor
static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c.epilogue;
static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c;
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
@@ -89,18 +89,18 @@ struct ConvFwdDlFactory
// The DL forward convolution kernel class instance
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
SPATIAL_DIM,
typename Types::ADataType,
typename Types::BDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::DsDataType,
typename Types::OutDataType,
typename Types::AccDataType,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
typename Layouts::OutLayout,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
FWD_CONV_SPECIALIZATION,
GEMM_SPECIALIZATION,
BLOCK.block_size,

View File

@@ -26,68 +26,65 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdLargeTensorFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm;
static constexpr auto FWD_CONV_SPECIALIZATION =
internal::SetFwdConvSpecialization<BASE_ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<BASE_ALGORITHM>();
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
.gemm_spec = GEMM_SPECIALIZATION};
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<BASE_ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<BASE_ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm;
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetFwdConvBlockTransfer<BASE_ALGORITHM.transfer.a>();
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
internal::SetFwdConvBlockTransfer<BASE_ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER =
internal::SetCBlockTransfer<SIGNATURE, BASE_ALGORITHM>();
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
// Check limits for the algorithm parameters.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
// The forward convolution kernel class instance with large tensor support.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Types::ADataType,
typename Types::BDataType,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
typename Types::OutComputeType,
typename Types::DsDataType,
typename Types::OutDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
BASE_ALGORITHM.num_gemm_k_prefetch_stages,
ALGORITHM.num_gemm_k_prefetch_stages,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.ak1,
GRIDWISE_GEMM.bk1,
GRIDWISE_GEMM.m_per_xdl,
GRIDWISE_GEMM.n_per_xdl,
GRIDWISE_GEMM.m_xdl_per_wave,
GRIDWISE_GEMM.n_xdl_per_wave,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
@@ -106,8 +103,8 @@ struct ConvFwdLargeTensorFactory
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
typename Types::AComputeType,
typename Types::BComputeType,
typename Types::InComputeType,
typename Types::WeiComputeType,
LOOP_SCHEDULER>;
};

View File

@@ -26,10 +26,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdXdlV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load ==
ALGORITHM.transfer.b.lds_transfer.is_direct_load,
@@ -43,6 +43,7 @@ struct ConvFwdXdlV3Factory
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
@@ -55,27 +56,27 @@ struct ConvFwdXdlV3Factory
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Types::ADataType,
typename Types::BDataType,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
typename Types::OutComputeType,
typename Types::DsDataType,
typename Types::OutDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
BLOCK.block_size,
@@ -84,10 +85,10 @@ struct ConvFwdXdlV3Factory
BLOCK.per_block.k,
GRIDWISE_GEMM.ak1,
GRIDWISE_GEMM.bk1,
GRIDWISE_GEMM.m_per_xdl,
GRIDWISE_GEMM.n_per_xdl,
GRIDWISE_GEMM.m_xdl_per_wave,
GRIDWISE_GEMM.n_xdl_per_wave,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
@@ -108,8 +109,8 @@ struct ConvFwdXdlV3Factory
C_BLOCK_TRANSFER.scalar_per_vector,
BLOCK_GEMM.scheduler,
BLOCK_GEMM.pipeline_version,
typename Types::AComputeType,
typename Types::BComputeType,
typename Types::InComputeType,
typename Types::WeiComputeType,
IS_DIRECT_LOAD>;
};

View File

@@ -26,10 +26,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdWmmaFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
@@ -52,27 +52,27 @@ struct ConvFwdWmmaFactory
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Types::ADataType,
typename Types::BDataType,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
typename Types::OutComputeType,
typename Types::DsDataType,
typename Types::OutDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
ALGORITHM.num_gemm_k_prefetch_stages,

View File

@@ -26,10 +26,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdXdlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
@@ -39,6 +39,7 @@ struct ConvFwdXdlFactory
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
@@ -50,27 +51,27 @@ struct ConvFwdXdlFactory
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Types::ADataType,
typename Types::BDataType,
typename Layouts::OutLayout,
typename Types::InDataType,
typename Types::WeiDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
typename Types::OutComputeType,
typename Types::DsDataType,
typename Types::OutDataType,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
ALGORITHM.num_gemm_k_prefetch_stages,
@@ -80,10 +81,10 @@ struct ConvFwdXdlFactory
BLOCK.per_block.k,
GRIDWISE_GEMM.ak1,
GRIDWISE_GEMM.bk1,
GRIDWISE_GEMM.m_per_xdl,
GRIDWISE_GEMM.n_per_xdl,
GRIDWISE_GEMM.m_xdl_per_wave,
GRIDWISE_GEMM.n_xdl_per_wave,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
@@ -102,10 +103,10 @@ struct ConvFwdXdlFactory
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
typename Types::AComputeType,
typename Types::BComputeType,
typename Types::InComputeType,
typename Types::WeiComputeType,
LOOP_SCHEDULER,
ALGORITHM.num_groups_to_merge>;
ALGORITHM.num_conv_groups_to_merge>;
};
} // namespace ck_tile::builder::factory

View File

@@ -10,27 +10,28 @@
namespace ck_tile::builder::factory::internal {
// Block transfer parameters for A or B tensor.
template <size_t ThreadClusterRank = 3>
struct BlockTransfer
{
ck::Array<size_t, 3> thread_cluster_dims = {0, 0, 0}; // k0, m, k1
ck::Array<size_t, 3> thread_cluster_order = {0, 0, 0};
ck::Array<size_t, 3> src_access_order = {0, 0, 0};
size_t src_vector_dim = 0;
size_t src_scalar_per_vector = 0;
size_t lds_dst_scalar_per_vector = 0;
bool is_direct_load = false;
bool lds_padding = false;
ck::Array<size_t, ThreadClusterRank> thread_cluster_dims{};
ck::Array<size_t, ThreadClusterRank> thread_cluster_order{};
ck::Array<size_t, ThreadClusterRank> src_access_order{};
size_t src_vector_dim = 0;
size_t src_scalar_per_vector = 0;
size_t lds_dst_scalar_per_vector = 0;
bool is_direct_load = false;
bool lds_padding = false;
};
template <auto TRANSFER>
constexpr BlockTransfer SetFwdConvBlockTransfer()
constexpr BlockTransfer<> SetFwdConvBlockTransfer()
{
auto& block_xfer = TRANSFER.block_transfer;
auto& block_order = TRANSFER.block_transfer_access_order;
auto& src_order = TRANSFER.src_access_order;
auto& lds_cfg = TRANSFER.lds_transfer;
return BlockTransfer{
return BlockTransfer<>{
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]},
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
@@ -42,6 +43,59 @@ constexpr BlockTransfer SetFwdConvBlockTransfer()
};
}
// Converts a backward-convolution transfer descriptor (TRANSFER) into the internal
// BlockTransfer parameter pack.
//
// The rank of the result follows the length of TRANSFER.block_transfer_access_order.order:
//   3 -> BlockTransfer<3> with thread cluster dims {k0, m_n, k1}
//   4 -> BlockTransfer<4> with thread cluster dims {k_batch_size, k0, m_n, k1}
// Any other length is rejected at compile time. The source access order must have the
// same length as the block-transfer access order. is_direct_load is not set here, so it
// keeps whatever default BlockTransfer gives it.
template <auto TRANSFER>
constexpr auto SetBwdConvBlockTransfer()
{
    const auto& cluster      = TRANSFER.block_transfer;
    const auto& cluster_perm = TRANSFER.block_transfer_access_order;
    const auto& source_perm  = TRANSFER.src_access_order;
    const auto& lds          = TRANSFER.lds_transfer;

    static_assert(cluster_perm.order.size() == source_perm.order.size(),
                  "Mismatched size between block order and src order");

    // Both access orders share this rank (checked above).
    constexpr auto rank = cluster_perm.order.size();

    if constexpr(rank == 4)
    {
        // 4D variant: the extra leading dimension is the K batch size.
        return BlockTransfer<4>{
            .thread_cluster_dims       = {cluster.k_batch_size,
                                          cluster.k0,
                                          cluster.m_n,
                                          cluster.k1},
            .thread_cluster_order      = {cluster_perm.order[0],
                                          cluster_perm.order[1],
                                          cluster_perm.order[2],
                                          cluster_perm.order[3]},
            .src_access_order          = {source_perm.order[0],
                                          source_perm.order[1],
                                          source_perm.order[2],
                                          source_perm.order[3]},
            .src_vector_dim            = lds.src_vector_dim,
            .src_scalar_per_vector     = lds.src_scalar_per_vector,
            .lds_dst_scalar_per_vector = lds.lds_dst_scalar_per_vector,
            .lds_padding               = lds.lds_padding,
        };
    }
    else if constexpr(rank == 3)
    {
        return BlockTransfer<3>{
            .thread_cluster_dims       = {cluster.k0, cluster.m_n, cluster.k1},
            .thread_cluster_order      = {cluster_perm.order[0],
                                          cluster_perm.order[1],
                                          cluster_perm.order[2]},
            .src_access_order          = {source_perm.order[0],
                                          source_perm.order[1],
                                          source_perm.order[2]},
            .src_vector_dim            = lds.src_vector_dim,
            .src_scalar_per_vector     = lds.src_scalar_per_vector,
            .lds_dst_scalar_per_vector = lds.lds_dst_scalar_per_vector,
            .lds_padding               = lds.lds_padding,
        };
    }
    else
    {
        static_assert(false, "Internal error: Unsupported array length");
    }
}
// Block transfer parameters for C tensor.
struct CBlockTransfer
{

View File

@@ -62,14 +62,15 @@ consteval auto GetElementwiseOp()
}
template <auto Sig>
struct ElementwiseOps
struct ConvElementwiseOps
{
static constexpr auto input_op = GetElementwiseOp<Sig.input>();
static constexpr auto weight_op = GetElementwiseOp<Sig.weight>();
static constexpr auto output_op = GetElementwiseOp<Sig.output>();
using AElementwiseOp = typename decltype(input_op)::Op;
using BElementwiseOp = typename decltype(weight_op)::Op;
using CDEElementwiseOp = typename decltype(output_op)::Op;
using InElementwiseOp = typename decltype(input_op)::Op;
using WeiElementwiseOp = typename decltype(weight_op)::Op;
using OutElementwiseOp = typename decltype(output_op)::Op;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -190,7 +190,7 @@ consteval auto GetAuxiliaryTensorLayoutTuple(std::index_sequence<Indices...>)
decltype(TensorLayoutToCK<AuxiliaryTensorConfigsArray[Indices].layout>())...>{};
}
template <auto AuxiliaryTensorConfigsValue, size_t SPATIAL_DIM, ConvDirection DIR>
template <auto AuxiliaryTensorConfigsValue, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM>)
struct AuxiliaryTensorLayouts
{
@@ -200,34 +200,32 @@ struct AuxiliaryTensorLayouts
};
// TODO: Currently only the output tensor can have auxiliary tensors (e.g., bias).
template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
template <auto Signature, size_t SPATIAL_DIM>
requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
consteval auto GetAuxiliaryTensorLayouts()
{
return AuxiliaryTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
SPATIAL_DIM,
DIR>{};
SPATIAL_DIM>{};
}
template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
template <auto Signature, size_t SPATIAL_DIM>
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
consteval auto GetAuxiliaryTensorLayouts()
{
return EmptyAuxiliaryTensorLayout{};
}
template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
template <auto Signature, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM> &&
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
struct ConvTensorLayouts
{
static_assert(DIR == ConvDirection::FORWARD, "Only Forward convolution is supported.");
using ALayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
using BLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
using ELayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
using DsLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM, DIR>())::type;
using InLayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
using WeiLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
using OutLayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
using DsLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM>())::type;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -33,7 +33,7 @@ struct DataTypeToCK<DataType::FP32>
using type = float;
};
template <>
struct DataTypeToCK<DataType::INT32>
struct DataTypeToCK<DataType::I32>
{
using type = int32_t;
};
@@ -156,7 +156,7 @@ consteval auto GetAuxiliaryTensorDataTypes()
}
template <auto Signature>
struct FwdConvTensorDataTypes
struct ConvTensorDataTypes
{
static constexpr auto input_types =
GetTensorDataAndComputeTypes<Signature.input.config, Signature.data_type>();
@@ -165,20 +165,17 @@ struct FwdConvTensorDataTypes
static constexpr auto output_types =
GetTensorDataAndComputeTypes<Signature.output.config, Signature.data_type>();
using ADataType = typename decltype(input_types.first)::type;
using AComputeType = typename decltype(input_types.second)::type;
using BDataType = typename decltype(weight_types.first)::type;
using BComputeType = typename decltype(weight_types.second)::type;
using InDataType = typename decltype(input_types.first)::type;
using InComputeType = typename decltype(input_types.second)::type;
using WeiDataType = typename decltype(weight_types.first)::type;
using WeiComputeType = typename decltype(weight_types.second)::type;
using OutDataType = typename decltype(output_types.first)::type;
using OutComputeType = typename decltype(output_types.second)::type;
using AccDataType =
typename decltype(GetTensorAccumulationType<Signature.accumulation_data_type,
Signature.data_type>())::type;
using EDataType = typename decltype(output_types.first)::type;
// This is the "compute" type for output.
using CShuffleDataType = typename decltype(output_types.second)::type;
// Data types for the auxiliary tensors (e.g., bias).
using DsDataTypes = typename decltype(GetAuxiliaryTensorDataTypes<Signature>())::type;
using DsDataType = typename decltype(GetAuxiliaryTensorDataTypes<Signature>())::type;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
@@ -37,7 +38,7 @@ struct BlockGemmSpec
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval BlockGemmSpec SetBlockGemm()
{
constexpr auto& BG = ALGORITHM.block_gemm;
constexpr auto& BG = ALGORITHM.block_gemm_pipeline;
ck::BlockGemmPipelineScheduler scheduler;
ck::BlockGemmPipelineVersion version;
@@ -82,7 +83,7 @@ consteval ck::LoopScheduler SetLoopScheduler()
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
{
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
constexpr auto pipeline_version = ALGORITHM.pipeline_version;
using ck_pipeline = ck::PipelineVersion;
switch(pipeline_version)
{
@@ -149,12 +150,30 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC
using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization;
switch(specialization)
{
case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default;
case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
case ConvFwdSpecialization::ODD_C: return ck_conv_spec::OddC;
default: throw "Unknown ConvFwdSpecialization";
case ConvSpecialization::DEFAULT: return ck_conv_spec::Default;
case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
case ConvSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
case ConvSpecialization::ODD_C: return ck_conv_spec::OddC;
default: throw "Unsupported ConvSpecialization";
}
}
/// @brief Translate the builder's backward-weight convolution specialization into
/// the corresponding CK `ConvolutionBackwardWeightSpecialization` value.
///
/// The specialization is read from `ALGORITHM.bwd_weight_specialization`. Because
/// this function is `consteval`, reaching one of the `throw` statements makes the
/// call ill-formed, i.e. an unsupported value is reported at compile time.
///
/// @tparam ALGORITHM The convolution algorithm descriptor to read from.
/// @return The matching CK backward-weight specialization.
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization
SetBwdWeightConvSpecialization()
{
    using CkSpec = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;

    constexpr auto requested = ALGORITHM.bwd_weight_specialization;
    switch(requested)
    {
    // Values that have no backward-weight counterpart.
    case ConvSpecialization::FILTER_3x3:
        throw "FILTER_3x3 is not supported for backward weight convolution.";

    // Values that map one-to-one onto a CK specialization.
    case ConvSpecialization::ODD_C: return CkSpec::OddC;
    case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return CkSpec::Filter1x1Stride1Pad0;
    case ConvSpecialization::FILTER_1X1_PAD0: return CkSpec::Filter1x1Pad0;
    case ConvSpecialization::DEFAULT: return CkSpec::Default;

    default: throw "Unsupported ConvSpecialization";
    }
}

View File

@@ -26,11 +26,11 @@ struct ReferenceFactory
static constexpr auto kValidation = (internal::ValidateReferenceSignature<SIGNATURE>(), 0);
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using InDataType = typename Types::ADataType;
using WeiDataType = typename Types::BDataType;
using OutDataType = typename Types::EDataType;
using InDataType = typename Types::InDataType;
using WeiDataType = typename Types::WeiDataType;
using OutDataType = typename Types::OutDataType;
struct Instance
{

View File

@@ -63,10 +63,7 @@ struct GemmAlgorithmInfo
OutputTileTransferInfo c_tile_transfer;
builder::PipelineVersion pipeline_version;
builder::PipelineScheduler pipeline_scheduler;
std::variant<builder::ConvFwdSpecialization,
builder::ConvBwdDataSpecialization,
builder::ConvBwdWeightSpecialization>
conv_specialization;
builder::ConvSpecialization conv_specialization;
builder::GemmPadding padding;
};

View File

@@ -197,18 +197,16 @@ constexpr builder::ConvDirection conv_direction()
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
/// @tparam Instance The device kernel instance type.
/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or
/// `builder::ConvBwdWeightSpecialization` enum value.
/// @return A `builder::ConvSpecialization` enum value.
template <typename Instance>
constexpr auto conv_spec()
{
using InstTraits = InstanceTraits<Instance>;
using enum builder::ConvSpecialization;
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
using enum builder::ConvFwdSpecialization;
switch(InstTraits::kConvForwardSpecialization)
{
case Default: return DEFAULT;
@@ -221,8 +219,6 @@ constexpr auto conv_spec()
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
using enum builder::ConvBwdDataSpecialization;
switch(InstTraits::kConvBwdDataSpecialization)
{
case Default: return DEFAULT;
@@ -232,8 +228,6 @@ constexpr auto conv_spec()
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
{
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
using enum builder::ConvBwdWeightSpecialization;
switch(InstTraits::kConvBwdWeightSpecialization)
{
case Default: return DEFAULT;

View File

@@ -35,10 +35,10 @@ struct ReferenceCommonTraits
typename builder::factory::internal::LayoutToCK<SIGNATURE.output.config.layout>::type;
// Data types - extract from factory's type helper
using Types = builder::factory::internal::FwdConvTensorDataTypes<SIGNATURE>;
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using EDataType = typename Types::EDataType;
using Types = builder::factory::internal::ConvTensorDataTypes<SIGNATURE>;
using ADataType = typename Types::InDataType;
using BDataType = typename Types::WeiDataType;
using EDataType = typename Types::OutDataType;
using AccDataType = float; // Reference uses float accumulation
// Elementwise operations - reference only supports PassThrough

View File

@@ -7,6 +7,7 @@
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/testing/testing.hpp"
#include "ck_tile/builder/testing/testing_reflect.hpp"
#include "ck_tile/builder/testing/filter_extent.hpp"
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "ck_tile/builder/testing/tensor_initialization.hpp"
@@ -71,11 +72,10 @@ struct Args<SIGNATURE>
using OutputDescriptor = TensorDescriptor<OUTPUT_TYPE, OUTPUT_RANK>;
// TODO: We shouldn't need to call into an internal namespace here.
using Ops = factory::internal::ElementwiseOps<SIGNATURE>;
using Ops = factory::internal::ConvElementwiseOps<SIGNATURE>;
// TODO: We shouldn't need to call into an internal namespace here.
using Layouts =
factory::internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
using Layouts = factory::internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
ConvTensorLengths<SPATIAL_DIM> lengths;
@@ -89,9 +89,9 @@ struct Args<SIGNATURE>
FilterExtent<SPATIAL_DIM> input_left_pad;
FilterExtent<SPATIAL_DIM> input_right_pad;
Ops::AElementwiseOp a_elementwise_op;
Ops::BElementwiseOp b_elementwise_op;
Ops::CDEElementwiseOp cde_elementwise_op;
Ops::InElementwiseOp a_elementwise_op;
Ops::WeiElementwiseOp b_elementwise_op;
Ops::OutElementwiseOp cde_elementwise_op;
/// This function returns the `TensorDescriptor` corresponding to
/// the input-tensor of the convolution problem. This can then
@@ -106,7 +106,7 @@ struct Args<SIGNATURE>
// function.
const auto param = to_ck_conv_param();
const auto desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
typename Layouts::ALayout>(param);
typename Layouts::InLayout>(param);
using Extent = typename InputDescriptor::Extent;
return InputDescriptor(Extent::from_vector(desc.GetLengths()),
Extent::from_vector(desc.GetStrides()));
@@ -120,7 +120,7 @@ struct Args<SIGNATURE>
// See note in implementation of `make_input_descriptor`.
const auto param = to_ck_conv_param();
const auto desc = ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<
typename Layouts::BLayout>(param);
typename Layouts::WeiLayout>(param);
using Extent = typename WeightDescriptor::Extent;
return WeightDescriptor(Extent::from_vector(desc.GetLengths()),
Extent::from_vector(desc.GetStrides()));
@@ -134,7 +134,7 @@ struct Args<SIGNATURE>
// See note in implementation of `make_input_descriptor`.
const auto param = to_ck_conv_param();
const auto desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<
typename Layouts::ELayout>(param);
typename Layouts::OutLayout>(param);
using Extent = typename OutputDescriptor::Extent;
return OutputDescriptor(Extent::from_vector(desc.GetLengths()),
Extent::from_vector(desc.GetStrides()));
@@ -182,6 +182,12 @@ struct Inputs<SIGNATURE>
{
void* input;
void* weight;
/// Enumerate the tensor fields of this struct.
///
/// Calls `inspect` once per field, passing the field's name, the tensor
/// descriptor that `args` produces for it, and a pointer-to-member that
/// identifies the field.
///
/// @param args The convolution arguments the descriptors are derived from.
/// @param inspect Callable invoked as `inspect(name, descriptor, member_ptr)`.
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
{
inspect("input", args.make_input_descriptor(), &Inputs<SIGNATURE>::input);
inspect("weight", args.make_weight_descriptor(), &Inputs<SIGNATURE>::weight);
}
};
/// @brief `Outputs` specialization for forward convolution.
@@ -194,68 +200,13 @@ template <auto SIGNATURE>
struct Outputs<SIGNATURE>
{
void* output;
};
/// @brief `UniqueInputs` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see UniqueInputs
/// @see ValidUniqueInputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
struct UniqueInputs<SIGNATURE>
{
DeviceBuffer input_buf;
DeviceBuffer weight_buf;
/// @see ValidUniqueInputs
Inputs<SIGNATURE> get()
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
{
return {
.input = input_buf.get(),
.weight = weight_buf.get(),
};
inspect("output", args.make_output_descriptor(), &Outputs<SIGNATURE>::output);
}
};
/// @brief `UniqueOutputs` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see UniqueOutputs
/// @see ValidUniqueOutputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
struct UniqueOutputs<SIGNATURE>
{
DeviceBuffer output_buf;
/// @see ValidUniqueOutputs
Outputs<SIGNATURE> get()
{
return {
.output = output_buf.get(),
};
}
};
/// @brief `alloc_inputs()` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see alloc_inputs()
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE> &&
ValidUniqueInputs<SIGNATURE>
UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args)
{
return {
.input_buf = alloc_tensor_buffer(args.make_input_descriptor()),
.weight_buf = alloc_tensor_buffer(args.make_weight_descriptor()),
};
}
/// @brief `init_inputs()` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
@@ -269,34 +220,4 @@ void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f);
}
/// @brief `alloc_outputs()` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see alloc_outputs()
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE> &&
ValidUniqueOutputs<SIGNATURE>
UniqueOutputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args)
{
return {
.output_buf = alloc_tensor_buffer(args.make_output_descriptor()),
};
}
/// @brief `validate()` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see validate()
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
ValidationReport
validate(const Args<SIGNATURE>& args, Outputs<SIGNATURE> actual, Outputs<SIGNATURE> expected)
{
ValidationReport report;
report.check("output", args.make_output_descriptor(), actual.output, expected.output);
return report;
}
} // namespace ck_tile::builder::test

View File

@@ -27,7 +27,7 @@ template <typename Conv,
auto SIGNATURE,
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
// TODO: We shouldn't need to call into an internal namespace here.
typename Ops = factory::internal::ElementwiseOps<SIGNATURE>>
typename Ops = factory::internal::ConvElementwiseOps<SIGNATURE>>
concept CkConvInstance = requires(Conv& conv,
// TODO: This should be changed depending on IsMultiA etc.
// Currently that is not yet supported elsewhere anyway.
@@ -37,9 +37,9 @@ concept CkConvInstance = requires(Conv& conv,
std::array<index_t, SPATIAL_DIM + 3> lengths,
std::array<index_t, SPATIAL_DIM + 3> strides,
std::array<index_t, SPATIAL_DIM> filter,
Ops::AElementwiseOp elementwise_a,
Ops::BElementwiseOp elementwise_b,
Ops::CDEElementwiseOp elementwise_cde) {
Ops::InElementwiseOp elementwise_a,
Ops::WeiElementwiseOp elementwise_b,
Ops::OutElementwiseOp elementwise_cde) {
{
conv.MakeArgument(p_a,
p_b,

View File

@@ -0,0 +1,634 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "ck_tile/builder/testing/error.hpp"
#include "ck_tile/builder/testing/type_traits.hpp"
#include "ck/utility/type_convert.hpp"
#include <concepts>
#include <cstddef>
#include <iomanip>
#include <iostream>
#include <limits>
#include <locale>
#include <sstream>
#include <string>
#include <string_view>
#include <syncstream>
/// This file contains a few debugging utilities, mainly focused around
/// tensor data. The idea is that the functionality in this file is not
/// necessarily used in any testing directly, but is available for the
/// programmer to help with debugging problems. These utilities themselves
/// should be tested just the same, though, so that they don't undergo
/// bitrot while they are not actively being used.
namespace ck_tile::builder::test {
namespace detail {
/// @brief Number punctuation facet used by the CK-Builder debugging helpers.
///
/// Debug output is normally produced under the default C locale, which prints
/// integers without any digit grouping and therefore makes large numbers hard
/// to read. This facet derives from `std::numpunct<char>` and groups digits in
/// blocks of three, separated by `'` (the same character C++14 accepts as a
/// digit separator in literals), e.g. `1'234'567`.
///
/// @note Imbue this facet only into a stream you own, or restore the previous
/// locale afterwards, so that a locale chosen by the user is not lost.
///
/// @see std::numpunct
struct numpunct : std::numpunct<char>
{
    /// The character inserted between digit groups.
    char_type do_thousands_sep() const override { return '\''; }

    /// The grouping rule: a single entry of 3 means "every group has 3 digits".
    string_type do_grouping() const override { return string_type(1, '\3'); }
};
} // namespace detail
/// @brief Print information about a tensor descriptor.
///
/// This function dumps useful information from a tensor descriptor to a
/// stream, `std::cout` by default. This includes the number of elements
/// in the tensor, the size of the backing space, lengths, strides, etc.
///
/// @note All information is printed using a lightly modified locale to
/// get a unified printing experience. The original locale in `stream` is
/// temporarily replaced, but restored before the function returns.
///
/// @tparam DT The tensor element datatype
/// @tparam RANK The rank (number of spatial dimensions) of the tensor.
///
/// @param name A name for the tensor descriptor.
/// @param desc The tensor descriptor to print.
/// @param out The stream to print to, `std::cout` by default.
template <DataType DT, size_t RANK>
void print_descriptor(std::string_view name,
const TensorDescriptor<DT, RANK>& desc,
std::ostream& out = std::cout)
{
// Create a custom stream with a completely new config (locale,
/// precision, fill, etc). Use an osyncstream to buffer the output
/// while were at it (its not likely to help a lot, but why not).
std::osyncstream stream(out.rdbuf());
stream.imbue(std::locale(std::locale(), new detail::numpunct{}));
// Print name along with some generic info
const auto size = desc.get_element_size();
const auto space = desc.get_element_space_size();
const auto bytes = desc.get_element_space_size_in_bytes();
const auto packed = desc.is_packed();
stream << "Descriptor \"" << name << "\":\n"
<< " data type: " << DT << '\n'
<< " size: " << size << " elements\n"
<< " space: " << space << " elements (" << bytes << " bytes)\n"
<< " lengths: " << desc.get_lengths() << '\n'
<< " strides: " << desc.get_strides() << '\n'
<< " packed: " << (packed ? "yes" : "no") << std::endl;
}
/// @brief User-tunable settings that control how tensors are printed.
///
/// The defaults keep the output compact. Use `TensorPrintConfig::unlimited()`
/// to print every element of a tensor regardless of its size.
struct TensorPrintConfig
{
    /// @brief Maximum number of values printed per row.
    ///
    /// A row is printed as a sequence of values. If the row holds more values
    /// than this limit, the surplus is replaced by `row_skip_val`.
    size_t col_limit = 10;

    /// @brief Maximum number of rows printed per 2D matrix.
    ///
    /// Tensors of rank 2 and above are printed as one matrix or as a series of
    /// matrix slices. If a matrix has more rows than this limit, the surplus is
    /// replaced by a row of `matrix_row_skip_val` (and possibly `row_skip_val`).
    size_t row_limit = 10;

    /// @brief Maximum number of 2D slices printed.
    ///
    /// Tensors of rank 3 and above are flattened into a sequence of slices, of
    /// which at most this many are printed.
    size_t slice_limit = 8;

    /// @brief Text emitted by `TensorPrinter` at the start of every row.
    std::string_view row_prefix = " ";

    /// @brief Text emitted by `TensorPrinter` between the values of a row.
    std::string_view row_field_sep = " ";

    /// @brief Placeholder emitted by `TensorPrinter` for omitted row values.
    ///
    /// Printed once in place of the values that are left out when a row is
    /// longer than `col_limit`.
    std::string_view row_skip_val = "...";

    /// @brief Placeholder emitted by `TensorPrinter` for omitted matrix rows.
    ///
    /// The vertical counterpart of `row_skip_val`: when rows are left out of a
    /// matrix, one row is printed in which ALL values are this placeholder.
    std::string_view matrix_row_skip_val = "...";

    /// @brief Number of decimal digits printed for floating point values.
    int float_precision = 3;

    /// @brief The default configuration with all three limits removed.
    ///
    /// Useful for printing the *entire* tensor. Be aware that this can produce
    /// a very large amount of output for big tensors!
    static constexpr TensorPrintConfig unlimited()
    {
        constexpr auto no_limit = std::numeric_limits<size_t>::max();

        TensorPrintConfig cfg{};
        cfg.col_limit   = no_limit;
        cfg.row_limit   = no_limit;
        cfg.slice_limit = no_limit;
        return cfg;
    }
};
namespace detail {
/// @brief Iterate over a range of indices, but limit the number of iterations.
///
/// Visits the indices `0..n` in ascending order. If `n > limit`, only the first
/// `ceil(limit / 2)` and the last `floor(limit / 2)` indices are visited, and
/// `delim` is invoked once in between. This makes it possible to walk large
/// ranges without visiting too many values. It's primarily used when printing
/// tensors, so that not all values of a giant tensor are dumped to the user's
/// terminal.
///
/// @param n The total number of items to iterate over.
/// @param limit The maximum number of items to visit. Use even values for best
/// results, as this leads to the same number of values in the "begin" and
/// "end" sections. With an odd value, the "begin" section gets the extra item.
/// @param f A functor to invoke for each visited element. The sole parameter is
/// the index.
/// @param delim A functor to invoke between the begin and end sections. The
/// sole parameter is the number of skipped items. It is only invoked if any
/// items are skipped at all.
template <typename F, typename Delim>
void limited_foreach(size_t n, size_t limit, F f, Delim delim)
{
    if(n <= limit)
    {
        // Everything fits: visit every index, never call `delim`.
        for(size_t i = 0; i < n; ++i)
            f(i);
        return;
    }

    const auto end_count = limit / 2;
    // Round up in case `limit` is odd. Written as a subtraction so that it
    // cannot wrap around for any value of `limit`.
    const auto begin_count = limit - end_count;
    const auto skip_count  = n - limit;

    for(size_t i = 0; i < begin_count; ++i)
        f(i);

    delim(skip_count);

    for(size_t i = n - end_count; i < n; ++i)
        f(i);
}
/// @brief Output stream requirements for use with `TensorPrinter`.
///
/// The `TensorPrinter` does not write to an ostream directly, but rather writes to
/// a custom stream object. This is mainly so that the user of `TensorPrinter` can
/// get more details than directly with an ostream. Basically, a valid implementation
/// of `TensorPrintStream` exposes 3 things:
/// - `val(...)`: a way to print (stringified) tensor elements.
/// - `msg(...)`: a way to print arbitrary text messages. These are mostly for
///   formatting. This should be implemented using varargs which are directly folded
///   into an ostream, so that <iomanip> functions can be used.
/// - `max_width`: a way to query the max width of any `val` field.
///
/// @see TensorPrinter for more information.
template <typename Stream>
concept TensorPrintStream = requires(Stream& stream, std::string_view val) {
    { stream.max_width } -> std::convertible_to<size_t>;
    { stream.val(val) } -> std::same_as<void>;
    { stream.msg() } -> std::same_as<void>;
    { stream.msg("msg") } -> std::same_as<void>;
    // `msg` must also accept stream manipulators mixed with text. The fill is
    // passed as a `char` on purpose: `std::setfill` deduces its character type
    // from the argument, so an integer argument would yield a manipulator that
    // cannot be inserted into a `char`-based ostream.
    { stream.msg(std::setw(3), std::setfill(' '), "msg", val) } -> std::same_as<void>;
};
/// @brief Utility to print tensors.
///
/// This structure implements the main logic for printing tensors to a stream.
/// In order to help with formatting, the `TensorPrinter` abstracts over a custom
/// stream type, see `TensorPrintStream`. This type is actually mostly an internal
/// helper and mainly used by `print_tensor`. It's supposed to be constructed
/// manually, but see the field docs for what is required.
///
/// @tparam DT The data type of the tensor to print.
/// @tparam RANK The rank (number of spatial dimensions) of the tensor to print.
///
/// @see print_tensor
template <DataType DT, size_t RANK>
struct TensorPrinter
{
/// The name of this tensor. This will be used during printing to add extra
/// clarity about what the user is seeing.
std::string_view name;
/// Configuration details of how to print the tensor. This should be able to
/// be specified by the user, but the default is good in most cases.
TensorPrintConfig config;
/// The lengths of the tensor to print. These values are directly from
/// `TensorDescriptor::get_lengths()`, stored here to avoid querying them
/// repeatedly.
Extent<RANK> lengths;
/// The strides of the tensor to print. These values are directly from
/// `TensorDescriptor::get_strides()`, stored here to avoid querying them
/// repeatedly.
Extent<RANK> strides;
/// The tensor's backing buffer. This memory should be host-accessible, for
/// example by copying it back to the host first.
const void* h_buffer;
/// A common stringstream for stringifying tensor values. This is here mostly
/// so that we can cache the internal allocation.
std::stringstream ss;
/// @brief Low-level tensor value stringifying function.
///
/// Print the value that `value` points to into the stringstream `ss` (member
/// value). This is the actual low-level printing function that prints each
/// element of the tensor. In order to get a robust printing implementation, the
/// value is written into a stringstream, which is then further processed to be
/// actually written to the output. This way, the format doesn't depend on the
/// configuration of the final ostream.
///
/// - `UNDEFINED_DATA_TYPE` is printed as `??` without reading `value`.
/// - Integer types are printed as decimal numbers.
/// - Floating point types are printed in fixed notation with
///   `config.float_precision` decimal digits; the 16-bit and 8-bit float types
///   are converted to `float` first.
///
/// @param value Pointer to the element to print. Interpreted as a
/// `detail::cpp_type_t<DT>`.
void stringify_value([[maybe_unused]] const void* value)
{
    if constexpr(DT == DataType::UNDEFINED_DATA_TYPE)
    {
        ss << "??";
    }
    else
    {
        // Everything below lives in the `else` branch on purpose: statements that
        // merely follow an `if constexpr` + `return` are still instantiated, so the
        // trailing `static_assert` would otherwise fire for UNDEFINED_DATA_TYPE.
        using CKType        = detail::cpp_type_t<DT>;
        const auto ck_value = *static_cast<const CKType*>(value);

        if constexpr(DT == DataType::I8 || DT == DataType::U8)
            // Widen before streaming: 8-bit integer types are character types to
            // iostreams and would be written as a raw character, not a number.
            // NOTE(review): assumes cpp_type_t maps I8/U8 to int8_t/uint8_t - confirm.
            ss << static_cast<int>(ck_value);
        else if constexpr(DT == DataType::I32)
            ss << ck_value;
        else if constexpr(DT == DataType::FP64 || DT == DataType::FP32)
            ss << std::fixed << std::setprecision(config.float_precision) << ck_value;
        else if constexpr(DT == DataType::FP16 || DT == DataType::BF16 ||
                          DT == DataType::FP8 || DT == DataType::BF8)
            ss << std::fixed
               << std::setprecision(config.float_precision)
               // Note: We are using CK types here (cpp_type_t uses DataTypeToCK), so
               // use CK's type_convert function.
               << ::ck::type_convert<float>(ck_value);
        else
            // TODO: Tuple types? Currently not implemented in DataTypeToCK...
            static_assert(false, "stringify_value unsupported data type, please implement");
    }
}
/// @brief Print a single tensor element to a stream.
///
/// Locates the element at `index` inside `h_buffer`, stringifies it with
/// `stringify_value` and hands the resulting text to `stream.val(...)`.
///
/// @param stream The stream to print to.
/// @param index The index in the tensor of the value to print.
void print_value(TensorPrintStream auto& stream, const Extent<RANK>& index)
{
    // Translate the element offset into a byte address within the host buffer.
    const auto element_offset = calculate_offset(index, strides);
    const auto* const bytes   = static_cast<const std::byte*>(h_buffer);
    const auto* element       = bytes + element_offset * data_type_sizeof(DT);

    // Rewind the shared stringstream instead of calling ss.str(""), which
    // would allocate.
    ss.clear();
    ss.seekg(0);
    ss.seekp(0);

    stringify_value(element);

    // Because the stream was only rewound, ss.view() covers the ENTIRE
    // buffer and may still contain leftovers from a previous, longer value.
    // tellp() reports how many characters were written this time, so cut the
    // view down to that length.
    const std::string_view whole = ss.view();
    stream.val(whole.substr(0, ss.tellp()));
}
/// @brief Print one row of tensor values to a stream.
///
/// Used both for 1D tensors and for the individual rows of a 2D matrix; the
/// row to print is selected by `index`. The number of printed values is capped
/// by `config.col_limit`, with `config.row_skip_val` standing in for the values
/// that are left out.
///
/// @param stream The stream to print to.
/// @param index The index of the row to print. The incoming value of the
/// rightmost element is ignored; it is overwritten with the position of each
/// value _within_ the row as the row is printed.
void print_row(TensorPrintStream auto& stream, Extent<RANK>& index)
{
    constexpr size_t col_dim = RANK - 1;

    // Prints the separator followed by the value in column `col`.
    const auto emit_value = [&](auto col) {
        stream.msg(config.row_field_sep);
        index[col_dim] = col;
        print_value(stream, index);
    };

    // Prints the placeholder for the omitted columns. This goes through
    // stream.msg(...) rather than stream.val(...) because the placeholder
    // should neither take part in the max_width computation nor be padded to
    // the max width.
    const auto emit_gap = [&]([[maybe_unused]] auto skip_count) {
        stream.msg(config.row_field_sep);
        stream.msg(config.row_skip_val);
    };

    // The skip row in `print_matrix` mirrors this layout.
    stream.msg(config.row_prefix);
    limited_foreach(lengths[col_dim], config.col_limit, emit_value, emit_gap);
    stream.msg('\n');
}
/// @brief Print a 2D matrix to a stream.
///
/// Print a matrix of tensor values to the stream. This function is used for both
/// 2D and slices of higher-dimensional tensors, in which the base coordinate is
/// given by `index`. Note that the print configuration is taken into account to
/// avoid flooding the user's terminal with values.
///
/// @param stream The stream to print to.
/// @param index The index of the matrix to print. The 2 rightmost index elements are
/// ignored, as those are the indices of values _within_ the matrix.
void print_matrix(TensorPrintStream auto& stream, Extent<RANK>& index)
{
    // A row that is actually printed: select it in `index` and delegate.
    auto emit_row = [&](auto row) {
        index[RANK - 2] = row;
        print_row(stream, index);
    };

    // The stand-in for the elided rows. It follows the same layout as a
    // regular row, but every value is replaced by `config.matrix_row_skip_val`.
    auto emit_skipped_rows = [&]([[maybe_unused]] auto row_skip_count) {
        // Placeholder in a value column. This one *does* go through `val`:
        // it should take part in the max-width computation and be padded
        // like a real value so that the columns stay aligned. It also means
        // the placeholder width only matters when a skip row is printed.
        auto emit_placeholder = [&]([[maybe_unused]] auto col) {
            stream.msg(config.row_field_sep);
            stream.val(config.matrix_row_skip_val);
        };

        // Placeholder for the elided columns. Printed via `msg` so that it
        // neither influences the max width nor gets padded.
        auto emit_col_skip = [&]([[maybe_unused]] auto col_skip_count) {
            stream.msg(config.row_field_sep);
            stream.msg(config.row_skip_val);
        };

        stream.msg(config.row_prefix);
        limited_foreach(lengths[RANK - 1], config.col_limit, emit_placeholder, emit_col_skip);
        stream.msg('\n');
    };

    limited_foreach(lengths[RANK - 2], config.row_limit, emit_row, emit_skipped_rows);
}
/// @brief Print a tensor to a stream.
///
/// This is the main tensor printing function. It calls `print_row` or `print_matrix`
/// (possibly repeatedly) as required. This function prints the entire tensor in
/// `h_buffer` regardless.
///
/// @param stream The stream to print to.
void print_tensor(TensorPrintStream auto& stream)
{
    // All-zero base coordinate used by the low-rank cases.
    Extent<RANK> origin = {};
    if constexpr(RANK == 0)
    {
        // Scalar: a single value, laid out like a one-element row.
        stream.msg(config.row_prefix);
        stream.msg(config.row_field_sep);
        print_value(stream, origin);
        stream.msg('\n');
    }
    else if constexpr(RANK == 1)
    {
        // Vector: everything goes on one line.
        print_row(stream, origin);
    }
    else if constexpr(RANK == 2)
    {
        // Matrix: one line per row.
        print_matrix(stream, origin);
    }
    else
    {
        // Higher ranks are printed as a sequence of 2D slices. The outer
        // dimensions are flattened so that `slice_limit` bounds the *total*
        // number of slices instead of the count per axis, and so that no
        // recursion is needed.
        Extent<RANK - 2> outer_shape;
        std::copy_n(lengths.begin(), RANK - 2, outer_shape.begin());
        NdIter iter(outer_shape);

        // Prints one slice: header line followed by the matrix itself.
        auto emit_slice = [&](auto outer_flat_index) {
            // Decode the flat slice number and widen it to a full index whose
            // two innermost elements start at zero.
            const auto outer_index = iter(outer_flat_index);
            Extent<RANK> index     = {};
            std::copy_n(outer_index.begin(), RANK - 2, index.begin());

            // Blank line between consecutive slices.
            if(outer_flat_index != 0)
                stream.msg('\n');

            // Header identifying the slice.
            stream.msg("Tensor \"", name, "\", slice [");
            for(auto x : outer_index)
                stream.msg(x, ", ");
            stream.msg(":, :]\n");

            // And print it as a matrix.
            print_matrix(stream, index);
        };

        // Reports how many slices were left out.
        auto emit_skipped = [&](auto skip_count) {
            stream.msg("\n(skipping ", skip_count, " slices...)\n");
        };

        detail::limited_foreach(iter.numel(), config.slice_limit, emit_slice, emit_skipped);
    }
}
};
/// @brief Implementation of `TensorPrintStream` to figure out the maximum
/// width of a field.
///
/// In order to produce neatly aligned tensors, where all values of each row
/// appear on the same columns, we have to figure out the maximum width of
/// each field. This print stream helps with that: It does not actually print
/// anything, it just figures out the maximum width of any value (not message).
///
/// @details OK, this function does actually print things, but only to an
/// internal `stringstream`. This is so that we can easily figure out the
/// width of the field (in bytes), just by counting the amount of bytes
/// written into the string stream.
///
/// @see TensorPrintStream
struct MaxFieldWidthStream
{
    /// Width (in bytes) of the widest value seen so far.
    size_t max_width = 0;

    /// @brief Record the width of a tensor value.
    ///
    /// Nothing is written anywhere; the size of `value` is merely folded
    /// into `max_width`.
    ///
    /// @param value The already-stringified value.
    void val(std::string_view value)
    {
        if(const size_t width = value.size(); width > max_width)
            max_width = width;
    }

    /// @brief Accept and discard a non-value message.
    ///
    /// Messages do not contribute to the field width, so every argument is
    /// ignored.
    ///
    /// @tparam Args The types of the values to "print".
    template <typename... Args>
    void msg(const Args&...)
    {
    }
};
/// @brief Implementation of `TensorPrintStream` which actually prints.
///
/// In contrast to `MaxFieldWidthStream`, this type actually prints to an
/// ostream, taking the maximum width computed by `MaxFieldWidthStream` into account.
struct OutputStream
{
    /// Destination of all output.
    std::ostream& stream;
    /// Minimum field width applied to every tensor value.
    size_t max_width;

    /// @brief Print a tensor value to the stream
    ///
    /// Writes `value` with the fill character set to a space and the field
    /// width set to `max_width`. The width only applies to this one
    /// insertion; the fill character stays set on the stream afterwards.
    ///
    /// @param value The value to print.
    void val(std::string_view value)
    {
        stream.fill(' ');
        stream.width(static_cast<int>(max_width));
        stream << value;
    }

    /// @brief Print a message to the stream.
    ///
    /// Inserts every argument into the ostream, in order, without any
    /// field-width handling.
    ///
    /// @tparam Args the types of the values to print.
    ///
    /// @param args The values to print.
    template <typename... Args>
    void msg(const Args&... args)
    {
        ((stream << args), ...);
    }
};
} // namespace detail
/// @brief Print device tensor values to an ostream.
///
/// Print the values of a tensor to an ostream. This function neatly formats
/// the tensor according to `config`, tabulating the values so that they are
/// vertically aligned and skipping values to prevent flooding the terminal.
/// With the default config, this function is good to get a quick overview
/// of what a tensor looks like. For a more complete overview, consider
/// supplying `TensorPrintConfig::unlimited()` to get everything (but beware
/// of flooding the terminal). Tensors are printed with the rightmost-dimension
/// as inner dimension, these values appear on the same row in the output.
///
/// @tparam DT The data type of the tensor.
/// @tparam RANK The rank (number of spatial dimensions) of the tensor.
///
/// @param name A name for the tensor. This will be used to add some extra identifying
/// information during printing.
/// @param desc The descriptor for the tensor memory layout.
/// @param d_buffer The tensor's actual data buffer. This is expected to be
/// _device accessible_ memory, as it is copied back to the host first.
/// @param config Tensor printing configuration. This allows tweaking some details
/// of the printing process.
/// @param out The ostream to print to, `std::cout` by default.
template <DataType DT, size_t RANK>
void print_tensor(std::string_view name,
const TensorDescriptor<DT, RANK>& desc,
const void* d_buffer,
TensorPrintConfig config = {},
std::ostream& out = std::cout)
{
// Copy memory to the host (printing from device is sketchy)
const auto space = desc.get_element_space_size_in_bytes();
std::vector<std::byte> h_buffer(space);
check_hip(hipMemcpy(h_buffer.data(), d_buffer, space, hipMemcpyDeviceToHost));
// Create a custom stream with a completely new config (locale,
// precision, fill, etc). Use an osyncstream to buffer the output
// while we're at it (it's not likely to help a lot, but why not).
// The new stream wraps `out`'s stream buffer, so formatting state set
// on `out` by the caller is not inherited and is not modified here.
std::osyncstream stream(out.rdbuf());
// Install the project's `detail::numpunct` facet (defined elsewhere) on top
// of the global locale. The locale takes ownership of the facet.
stream.imbue(std::locale(std::locale(), new detail::numpunct{}));
// Print a header for the entire tensor (regardless of if there are multiple slices).
stream << "Tensor \"" << name << "\": shape = " << desc.get_lengths() << "\n";
// Note that the printer only receives a raw pointer into `h_buffer`, so
// `h_buffer` has to stay alive for as long as `printer` is used.
detail::TensorPrinter<DT, RANK> printer = {
.name = name,
.config = config,
.lengths = desc.get_lengths(),
.strides = desc.get_strides(),
.h_buffer = h_buffer.data(),
.ss = std::stringstream(),
};
// We're actually going to print twice: once to figure out the
// maximum width of the fields, and once to actually print to the stream.
// Print once to figure out the maximum field width.
detail::MaxFieldWidthStream max_field_width;
printer.print_tensor(max_field_width);
// Actually print to the output stream.
detail::OutputStream tensor_out = {
.stream = stream,
.max_width = max_field_width.max_width,
};
printer.print_tensor(tensor_out);
}
} // namespace ck_tile::builder::test

View File

@@ -81,4 +81,15 @@ inline DeviceBuffer alloc_buffer(size_t size)
return DeviceBuffer(d_buf);
}
/// @brief "Align" an offset to a multiple of a particular alignment.
///
/// Returns `addr` aligned to the next multiple of `alignment`.
///
/// @param addr The address to align.
/// @param alignment The alignment.
inline size_t align_fwd(size_t addr, size_t alignment)
{
    // Distance of `addr` past the previous multiple of `alignment`.
    const size_t remainder = addr % alignment;

    // Already aligned: nothing to do.
    if(remainder == 0)
        return addr;

    // Drop back to the previous multiple, then step one alignment forward.
    return addr - remainder + alignment;
}
} // namespace ck_tile::builder::test

View File

@@ -7,6 +7,7 @@
#include <array>
#include <vector>
#include <sstream>
#include <iosfwd>
#include <concepts>
#include <algorithm>
#include <hip/hip_runtime.h>
@@ -123,6 +124,33 @@ struct Extent : std::array<size_t, RANK>
template <typename... T>
Extent(T...) -> Extent<sizeof...(T)>;
/// @brief Extent printer
///
/// This function implements an ostream printing overload for `Extent`, so that
/// they can be printed in the usual `stream << extent` fashion.
///
/// @tparam RANK Rank (number of spatial dimensions) of the extent.
///
/// @param stream The stream to print the extent to.
/// @param extent The extent to print to the stream.
template <size_t RANK>
std::ostream& operator<<(std::ostream& stream, const Extent<RANK>& extent)
{
    // Output looks like `[a, b, c]`; an extent of rank 0 prints as `[]`.
    stream << '[';
    for(size_t i = 0; i < RANK; ++i)
    {
        // The separator goes *between* elements, so skip it before the first.
        if(i != 0)
            stream << ", ";
        stream << extent[i];
    }
    stream << ']';
    return stream;
}
/// @brief Concept for automatically deriving tensor memory layout.
///
/// A `TensorStridesGenerator` is a type which can be used to automatically

View File

@@ -18,6 +18,102 @@
namespace ck_tile::builder::test {
/// @brief Utility structure for N-dimensional iteration using a flat index
///
/// This structure's main purpose is to "unmerge" a flattened index into a
/// multi-dimensional index, which helps when iterating over multi-dimensional
/// indices without having to write an arbitrary amount of nested for loops.
/// A minimal amount of precomputation must be done to do this efficiently,
/// which is handled in the constructor of this type.
///
/// @details Decoding a flat index into a multi-dimensional index is done by
/// first computing a reverse scan of the shape. These values can then be
/// used to decode the index in the usual way:
///
/// x = flat_idx / (size_y * size_z)
/// y = flat_idx % (size_y * size_z) / size_z
/// z = flat_idx % (size_y * size_z) % size_z
/// etc
///
/// The decode order is such that the innermost dimension (right in
/// the shape extent) changes the fastest.
///
/// @tparam RANK The rank (number of spatial dimensions) of the tensor to
/// iterate.
template <size_t RANK>
struct NdIter
{
    /// @brief Prepare N-dimensional iteration over a particular shape.
    ///
    /// Precompute a shape into a form that can be used to easily decode a
    /// flat index into a multi-dimensional index.
    ///
    /// @param shape The shape to iterate over.
    explicit NdIter(const Extent<RANK>& shape)
    {
        // Walk the shape from the innermost dimension outwards, producing
        // shape_scan = [..., shape[-2] * shape[-1], shape[-1], 1].
        // Once every dimension has been folded in, the running product is
        // the total element count.
        size_t running = 1;
        for(size_t dim = RANK; dim-- > 0;)
        {
            shape_scan_[dim] = running;
            running *= shape[dim];
        }
        numel_ = running;
    }

    /// @brief Unflatten a flat index into a multi-dimensional index
    ///
    /// Divides the flat index by each entry of the precomputed shape scan in
    /// turn, carrying the remainder over to the next dimension. The innermost
    /// dimension (rightmost in the shape extent) changes the fastest.
    ///
    /// @param flat_index The "flattened" (1-dimensional) index of the tensor
    ///
    /// @returns A multi-dimensional index into the tensor
    ///
    /// @pre `0 <= flat_index < numel()` (in other words, the `flat_index` must
    /// be in bounds of the tensor shape that this `NdIter` was made from).
    __host__ __device__ Extent<RANK> operator()(size_t flat_index) const
    {
        Extent<RANK> index = {};
        size_t remainder   = flat_index;
        for(size_t dim = 0; dim < RANK; ++dim)
        {
            const size_t stride = shape_scan_[dim];
            index[dim]          = remainder / stride;
            remainder           = remainder % stride;
        }
        return index;
    }

    /// @brief Return the total elements to iterate over
    ///
    /// This is the product of all dimensions of the shape, and serves as the
    /// upper bound of a loop over all flat indices, for example:
    ///
    ///     for(size_t i = 0; i < iter.numel(); ++i)
    ///     {
    ///         const auto index = iter(i);
    ///         use(index);
    ///     }
    __host__ __device__ size_t numel() const { return numel_; }

    private:
    /// Reverse (right) exclusive scan of the shape to iterate over.
    Extent<RANK> shape_scan_;
    /// The total number of elements in the shape. Stored because nearly
    /// every user that iterates over a shape needs it.
    size_t numel_;
};

template <size_t RANK>
NdIter(Extent<RANK>) -> NdIter<RANK>;
/// @brief Concept for constraining tensor iteration functors.
///
/// This concept checks that a functor has the correct signature for
@@ -50,28 +146,19 @@ constexpr int DEVICE_FOREACH_BLOCK_SIZE = 256;
/// @tparam F The type of the callback to invoke. This function must be
/// compatible with execution as a __device__ function.
///
/// @param numel The total number of elements in the tensor.
/// @param shape_scan A right-exclusive scan of the shape of the tensor.
/// @param iter An NdIter instance to help iterating over the tensor.
/// @param f The callback to invoke for each index of the tensor. This
/// functor must be eligible for running on the GPU.
template <int BLOCK_SIZE, size_t RANK, typename F>
requires ForeachFunctor<F, RANK>
__global__ __launch_bounds__(BLOCK_SIZE) //
void foreach_kernel(const size_t numel, Extent<RANK> shape_scan, F f)
void foreach_kernel(NdIter<RANK> iter, F f)
{
const auto gid = blockIdx.x * BLOCK_SIZE + threadIdx.x;
for(size_t flat_idx = gid; flat_idx < numel; flat_idx += gridDim.x * BLOCK_SIZE)
for(size_t flat_idx = gid; flat_idx < iter.numel(); flat_idx += gridDim.x * BLOCK_SIZE)
{
// Compute the current index.
Extent<RANK> index = {};
size_t idx = flat_idx;
for(size_t i = 0; i < RANK; ++i)
{
const auto scanned_dim = shape_scan[i];
index[i] = idx / scanned_dim;
idx %= scanned_dim;
}
const auto index = iter(flat_idx);
// Then invoke the callback with the index.
f(index);
@@ -160,18 +247,12 @@ void tensor_foreach(const Extent<RANK>& shape, ForeachFunctor<RANK> auto f)
// order in the kernel is from large-to-small. Right layout is the
// easiest solution for that.
Extent<RANK> shape_scan;
size_t numel = 1;
for(int i = RANK; i > 0; --i)
{
shape_scan[i - 1] = numel;
numel *= shape[i - 1];
}
NdIter iter(shape);
// Reset any errors from previous launches.
(void)hipGetLastError();
kernel<<<occupancy * multiprocessors, block_size>>>(numel, shape_scan, f);
kernel<<<occupancy * multiprocessors, block_size>>>(iter, f);
check_hip(hipGetLastError());
}
@@ -179,7 +260,7 @@ void tensor_foreach(const Extent<RANK>& shape, ForeachFunctor<RANK> auto f)
///
/// This concept checks that a functor has the correct signature for
/// use with the `fill_tensor` function.
template <typename F, builder::DataType DT, size_t RANK>
template <typename F, DataType DT, size_t RANK>
concept FillTensorFunctor = requires(const F& f, const Extent<RANK>& index) {
{ f(index) } -> std::convertible_to<detail::cpp_type_t<DT>>;
};
@@ -199,7 +280,7 @@ concept FillTensorFunctor = requires(const F& f, const Extent<RANK>& index) {
/// @param f A functor used to get the value at a particular coordinate.
///
/// @see FillTensorFunctor
template <builder::DataType DT, size_t RANK>
template <DataType DT, size_t RANK>
void fill_tensor(const TensorDescriptor<DT, RANK>& desc,
void* buffer,
FillTensorFunctor<DT, RANK> auto f)
@@ -218,7 +299,7 @@ void fill_tensor(const TensorDescriptor<DT, RANK>& desc,
///
/// This concept checks that a functor has the correct signature for
/// use with the `fill_tensor_buffer` function.
template <typename F, builder::DataType DT>
template <typename F, DataType DT>
concept FillTensorBufferFunctor = requires(const F& f, size_t index) {
{ f(index) } -> std::convertible_to<detail::cpp_type_t<DT>>;
};
@@ -239,7 +320,7 @@ concept FillTensorBufferFunctor = requires(const F& f, size_t index) {
/// @param f A functor used to get the value at a particular index.
///
/// @see FillTensorBufferFunctor
template <builder::DataType DT, size_t RANK>
template <DataType DT, size_t RANK>
void fill_tensor_buffer(const TensorDescriptor<DT, RANK>& desc,
void* buffer,
FillTensorBufferFunctor<DT> auto f)
@@ -247,7 +328,19 @@ void fill_tensor_buffer(const TensorDescriptor<DT, RANK>& desc,
fill_tensor(desc.get_space_descriptor(), buffer, [f](auto index) { return f(index[0]); });
}
template <builder::DataType DT, size_t RANK>
/// @brief Utility for clearing tensor buffers to a particular value.
///
/// This function initializes all memory backing a particular tensor buffer to
/// one specific value, zero by default. Note that this function ignores strides,
/// and clears the entire buffer backing the tensor.
///
/// @tparam DT The tensor element datatype
/// @tparam RANK The rank (number of spatial dimensions) of the tensor.
///
/// @param desc The descriptor of the tensor to initialize.
/// @param buffer The memory of the tensor to initialize.
/// @param value The value to initialize the tensor buffer with.
template <DataType DT, size_t RANK>
void clear_tensor_buffer(const TensorDescriptor<DT, RANK>& desc,
void* buffer,
detail::cpp_type_t<DT> value = detail::cpp_type_t<DT>{0})

View File

@@ -5,6 +5,8 @@
#include <concepts>
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "ck_tile/builder/testing/validation.hpp"
/// This file is the main header for the CK-Builder testing system. A high-level
@@ -132,8 +134,8 @@ struct Outputs;
/// be created using `alloc_inputs()` and that an instance of the corresponding
/// `Inputs` structure can be obtained using `.get()`.
///
/// @note The easiest way to implement this type is to use the `DeviceBuffer`
/// type to allocate individual device buffers for each input tensor.
/// @note A default implementation is provided for this type if `Inputs`
/// supports `TensorReflectable`.
///
/// @tparam SIGNATURE The signature to specialize the structure for.
///
@@ -151,8 +153,8 @@ struct UniqueInputs;
/// be created using `alloc_outputs()` and that an instance of the corresponding
/// `Outputs` structure can be obtained using `.get()`.
///
/// @note The easiest way to implement this type is to use the `DeviceBuffer`
/// type to allocate individual device buffers for each output tensor.
/// @note A default implementation is provided for this type if `Outputs`
/// supports `TensorReflectable`.
///
/// @tparam SIGNATURE The signature to specialize the structure for.
///
@@ -197,6 +199,12 @@ concept ValidUniqueOutputs = requires(UniqueOutputs<SIGNATURE>& inputs) {
/// amount of memory required and then allocate it on the device, for example
/// using `alloc_buffer` or `alloc_tensor_buffer`.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @note A default implementation is provided for this function if `Inputs`
/// supports `TensorReflectable`.
///
/// @tparam SIGNATURE The signature to specialize the structure for.
///
/// @param args The run-time arguments of the operation.
@@ -207,22 +215,22 @@ concept ValidUniqueOutputs = requires(UniqueOutputs<SIGNATURE>& inputs) {
/// @see alloc_tensor_buffer()
template <auto SIGNATURE>
requires ValidUniqueInputs<SIGNATURE>
UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args);
UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args) = delete;
/// @brief Allocate inputs corresponding to a signature.
/// @brief Initialize inputs corresponding to a signature.
///
/// The `init_inputs()` function is used to initialize pseudo-random data
/// to the tensors specified in the Inputs structure. Implementors should
/// fill each of the tensors in `inputs` with appropriate random data.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
///
/// @param args The run-time arguments of the operation.
/// @param inputs The operation inputs to initialize with random data.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @see Inputs
/// @see tensor_initialization
template <auto SIGNATURE>
@@ -235,13 +243,16 @@ void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs) = delete
/// amount of memory required and then allocate it on the device, for example
/// using `alloc_buffer` or `alloc_tensor_buffer`.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @note A default implementation is provided for this function if `Outputs`
/// supports `TensorReflectable`.
///
/// @tparam SIGNATURE The signature to specialize the structure for.
///
/// @param args The run-time arguments of the operation.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @see Outputs
/// @see UniqueOutputs
/// @see alloc_buffer()
@@ -262,15 +273,15 @@ UniqueInputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args) = delete;
/// were incorrect, and where (a subset of) those elements are located within
/// the tensor. See `ValidationReport` for more information about the report.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @tparam SIGNATURE The signature to specialize the structure for.
///
/// @param args The run-time arguments of the operation.
/// @param actual The actual results, the results of the operation to-be-tested.
/// @param expected The expected results, the results of the reference implementation.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @see ValidationReport
template <auto SIGNATURE>
ValidationReport validate(const Args<SIGNATURE>& args,

View File

@@ -0,0 +1,199 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string_view>
/// testing.hpp requires developers of a type of SIGNATURE to implement
/// quite a lot of functionality for each SIGNATURE. For example, next
/// to `Args`, `Inputs`, `Outputs`, `run`, they also have to define
/// `UniqueInputs`, `UniqueOutputs`, `alloc_inputs`, `alloc_outputs`,
/// and `validate`. The implementation of these latter few functions
/// is usually quite straight forward and adds a bunch of copy-paste
/// overhead. The functionality in this file offers an alternative
/// route: By implementing some reflection functionality in `Inputs`
/// and `Outputs`, we can automatically derive most of the functionality.
namespace ck_tile::builder::test {
/// @brief Check whether an `Input` or `Output` struct can be reflected.
///
/// In order to avoid having to manually redefine a bunch of types related to
/// each `Inputs`/`Outputs` structure, those structures can also provide some
/// "reflection" functionality. To this end, they should implement
/// `static void reflect(const Args<SIGNATURE>& args, auto inspect)`, where `inspect`
/// is called with information about each field in the struct. In more detail,
/// the signature of the `inspect` function is as follows:
///
/// void inspect(
/// // A human-readable name for the tensor
/// std::string_view name,
/// // Descriptor for the tensor in memory, usually obtained via `args`.
/// const TensorDescriptor<DT, RANK>& desc,
/// // Member pointer to a field of `T`, which is a GPU-memory pointer
/// // to the relevant tensor memory.
/// void* T::* ptr);
///
/// Here, `T` is `Inputs<SIGNATURE>` or `Outputs<SIGNATURE>`.
///
/// @see Inputs
/// @see Outputs
// The concept is satisfied when calling `T::reflect(args, inspect)` is
// well-formed for the probe callback below. The result of the call is not
// constrained, and the probe itself does nothing.
template <typename T, auto SIGNATURE>
concept TensorReflectable = requires(const Args<SIGNATURE>& args) {
{
T::reflect(args,
// Probe callback with the same parameter list that a real `inspect`
// callback is expected to accept.
[]([[maybe_unused]] std::string_view name,
// Note: This will be a TensorDescriptor<DT, RANK>, but the actual
// DT and RANK may differ depending on member.
[[maybe_unused]] const auto& desc,
[[maybe_unused]] void* T::*ptr) {})
};
};
namespace detail {
/// The default alignment between tensors allocated separately
/// by `UniqueTensors`. This should be large enough to accommodate
/// any type. hipMalloc returns an alignment of 256 by default.
constexpr size_t TENSOR_ALIGNMENT = 256;
/// @brief Common type for automatically managing memory of sets of tensors.
///
/// This type implements the automatic memory management logic for `Inputs` and
/// `Outputs` that support reflection.
///
/// @tparam SIGNATURE The signature to specialize the structure for.
/// @tparam Tensors The `Inputs` or `Outputs` type corresponding to `SIGNATURE`.
template <auto SIGNATURE, typename Tensors>
    requires TensorReflectable<Tensors, SIGNATURE>
struct UniqueTensors
{
    /// @brief Allocate tensors.
    ///
    /// This function computes the total size of memory to allocate according to
    /// the tensors that `Tensors::reflect` reports for `args`, allocates it as
    /// one contiguous buffer, and then points each reported field of `Tensors`
    /// at its own `TENSOR_ALIGNMENT`-aligned offset within that buffer.
    ///
    /// @param args The run-time arguments of the operation.
    explicit UniqueTensors(const Args<SIGNATURE>& args)
    {
        // First pass: compute the total size of all tensors combined. Only
        // local state is touched here, so `this` is not captured.
        size_t total_size = 0;
        Tensors::reflect(args,
                         [&]([[maybe_unused]] std::string_view name,
                             const auto& desc,
                             [[maybe_unused]] void* Tensors::*ptr) {
                             total_size = align_fwd(total_size, TENSOR_ALIGNMENT);
                             total_size += desc.get_element_space_size_in_bytes();
                         });

        data_ = alloc_buffer(total_size);

        // Second pass: assign the pointers. The offsets are recomputed with
        // exactly the same arithmetic as in the first pass, so they are
        // guaranteed to fit in the buffer that was just allocated.
        size_t offset = 0;
        Tensors::reflect(args,
                         [&, this]([[maybe_unused]] std::string_view name,
                                   const auto& desc,
                                   void* Tensors::*ptr) {
                             offset        = align_fwd(offset, TENSOR_ALIGNMENT);
                             tensors_.*ptr = data_.get() + offset;
                             offset += desc.get_element_space_size_in_bytes();
                         });
    }

    /// @brief Return raw `Inputs` or `Outputs` type.
    ///
    /// The returned struct holds non-owning pointers into the buffer owned by
    /// this object, so it must not be used after this object is destroyed.
    ///
    /// @see ValidUniqueInputs
    /// @see ValidUniqueOutputs
    Tensors get() const { return tensors_; }

    private:
    /// Owning pointer of the tensor memory.
    DeviceBuffer data_;
    /// Struct with pointers to each tensor. Stored here so that we
    /// don't need to keep recomputing it. Value-initialized so that any
    /// field that `reflect` does not report holds a well-defined value
    /// instead of an indeterminate one.
    Tensors tensors_{};
};
} // namespace detail
/// @brief Implementation of `UniqueInputs` for `Inputs` that support reflection.
///
/// @tparam SIGNATURE The signature to specialize for.
///
/// @see UniqueInputs
template <auto SIGNATURE>
requires TensorReflectable<Inputs<SIGNATURE>, SIGNATURE>
struct UniqueInputs<SIGNATURE> : detail::UniqueTensors<SIGNATURE, Inputs<SIGNATURE>>
{
// Inherit the base-class constructors so that `UniqueInputs<SIGNATURE>(args)`
// allocates through `detail::UniqueTensors`. No additional state is added here.
using detail::UniqueTensors<SIGNATURE, Inputs<SIGNATURE>>::UniqueTensors;
};
/// @brief Implementation of `UniqueOutputs` for `Outputs` that support reflection.
///
/// @tparam SIGNATURE The signature to specialize for.
///
/// @see UniqueOutputs
template <auto SIGNATURE>
requires TensorReflectable<Outputs<SIGNATURE>, SIGNATURE>
struct UniqueOutputs<SIGNATURE> : detail::UniqueTensors<SIGNATURE, Outputs<SIGNATURE>>
{
// Inherit the base-class constructors so that `UniqueOutputs<SIGNATURE>(args)`
// allocates through `detail::UniqueTensors`. No additional state is added here.
using detail::UniqueTensors<SIGNATURE, Outputs<SIGNATURE>>::UniqueTensors;
};
/// @brief Implementation of `alloc_inputs` for `Inputs` that support reflection.
///
/// @tparam SIGNATURE The signature to specialize for.
///
/// @param args The run-time arguments of the operation.
///
/// @see alloc_inputs
template <auto SIGNATURE>
requires TensorReflectable<Inputs<SIGNATURE>, SIGNATURE>
UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args)
{
// Compile-time check that the reflection-based `UniqueInputs` satisfies
// the `ValidUniqueInputs` concept for this signature.
static_assert(ValidUniqueInputs<SIGNATURE>, "sanity check");
// All allocation work happens in the `UniqueInputs` constructor.
return UniqueInputs<SIGNATURE>(args);
}
/// @brief Implementation of `alloc_outputs` for `Outputs` that support reflection.
///
/// @tparam SIGNATURE The signature to specialize for.
///
/// @param args The run-time arguments of the operation.
///
/// @see alloc_outputs
template <auto SIGNATURE>
requires TensorReflectable<Outputs<SIGNATURE>, SIGNATURE>
UniqueOutputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args)
{
// Compile-time check that the reflection-based `UniqueOutputs` satisfies
// the `ValidUniqueOutputs` concept for this signature.
static_assert(ValidUniqueOutputs<SIGNATURE>, "sanity check");
// All allocation work happens in the `UniqueOutputs` constructor.
return UniqueOutputs<SIGNATURE>(args);
}
/// @brief Implementation of `validate` for `Outputs` that support reflection.
///
/// @tparam SIGNATURE The signature to specialize for.
///
/// @param args The run-time arguments of the operation.
/// @param actual The actual results, the results of the operation to-be-tested.
/// @param expected The expected results, the results of the reference implementation.
///
/// @see validate
template <auto SIGNATURE>
    requires TensorReflectable<Outputs<SIGNATURE>, SIGNATURE>
ValidationReport
validate(const Args<SIGNATURE>& args, Outputs<SIGNATURE> actual, Outputs<SIGNATURE> expected)
{
    using OutputTensors = Outputs<SIGNATURE>;

    ValidationReport report;

    // Invoked once per output tensor reported by `reflect`: the same member
    // pointer selects the matching buffer in both `actual` and `expected`,
    // and the pair is handed to the report together with the tensor's name
    // and descriptor.
    auto compare_tensor = [&](std::string_view name,
                              const auto& desc,
                              void* OutputTensors::*ptr) {
        report.check(name, desc, actual.*ptr, expected.*ptr);
    };

    OutputTensors::reflect(args, compare_tensor);
    return report;
}
} // namespace ck_tile::builder::test

View File

@@ -39,7 +39,7 @@ constexpr size_t data_type_sizeof(DataType data_type)
case DataType::FP8: return 1;
case DataType::BF8: return 1;
case DataType::FP64: return 8;
case DataType::INT32: return 4;
case DataType::I32: return 4;
case DataType::I8: return 1;
case DataType::I8_I8: return 2;
case DataType::U8: return 1;

View File

@@ -7,7 +7,6 @@
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "ck_tile/builder/testing/tensor_foreach.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/type_convert.hpp"
#include <string_view>
#include <vector>

View File

@@ -24,7 +24,7 @@ enum class DataType
FP8,
BF8,
FP64,
INT32,
I32,
I8,
I8_I8,
U8
@@ -192,8 +192,8 @@ enum class TileConvSpecialization
FILTER_3x3
};
// Enums for the forward convolution specialization.
enum class ConvFwdSpecialization
// Enums for the convolution specializations.
enum class ConvSpecialization
{
DEFAULT,
FILTER_1X1_PAD0,
@@ -202,22 +202,6 @@ enum class ConvFwdSpecialization
ODD_C
};
// Enums for the backward data convolution specialization.
enum class ConvBwdDataSpecialization
{
DEFAULT,
FILTER_1X1_STRIDE1_PAD0,
};
// Enums for the backward weight convolution specialization.
enum class ConvBwdWeightSpecialization
{
DEFAULT,
FILTER_1X1_STRIDE1_PAD0,
FILTER_1X1_PAD0,
ODD_C,
};
// Enums for the Gemm padding.
enum class GemmPadding
{
@@ -249,11 +233,13 @@ enum class PipelineScheduler
enum class ConvAlgorithmSpecialization
{
LARGE_TENSOR,
REFERENCE // GPU reference implementation for validation
REFERENCE, // GPU reference implementation for validation,
TWO_STAGE,
MULTIPLE_D
};
// toString methods for enum classes
inline std::string_view toString(DataType dt)
// to_string methods for enum classes
inline std::string_view to_string(DataType dt)
{
using enum DataType;
switch(dt)
@@ -267,7 +253,7 @@ inline std::string_view toString(DataType dt)
case FP8: return "FP8";
case BF8: return "BF8";
case FP64: return "FP64";
case INT32: return "INT32";
case I32: return "I32";
case I8: return "I8";
case I8_I8: return "I8_I8";
case U8: return "U8";
@@ -276,7 +262,7 @@ inline std::string_view toString(DataType dt)
}
}
inline std::string_view toString(ConvDirection dir)
inline std::string_view to_string(ConvDirection dir)
{
using enum ConvDirection;
switch(dir)
@@ -288,7 +274,7 @@ inline std::string_view toString(ConvDirection dir)
}
}
inline std::string_view toString(ElementwiseOperation op)
inline std::string_view to_string(ElementwiseOperation op)
{
using enum ElementwiseOperation;
switch(op)
@@ -332,7 +318,7 @@ inline std::string_view toString(ElementwiseOperation op)
}
}
inline std::string_view toString(PipelineVersion ver)
inline std::string_view to_string(PipelineVersion ver)
{
using enum PipelineVersion;
switch(ver)
@@ -347,7 +333,7 @@ inline std::string_view toString(PipelineVersion ver)
}
}
inline std::string_view toString(GemmSpecialization spec)
inline std::string_view to_string(GemmSpecialization spec)
{
using enum GemmSpecialization;
switch(spec)
@@ -372,9 +358,9 @@ inline std::string_view toString(GemmSpecialization spec)
}
}
inline std::string_view toString(ConvFwdSpecialization spec)
inline std::string_view to_string(ConvSpecialization spec)
{
using enum ConvFwdSpecialization;
using enum ConvSpecialization;
switch(spec)
{
case DEFAULT: return "DEFAULT";
@@ -386,31 +372,7 @@ inline std::string_view toString(ConvFwdSpecialization spec)
}
}
inline std::string_view toString(ConvBwdDataSpecialization spec)
{
using enum ConvBwdDataSpecialization;
switch(spec)
{
case DEFAULT: return "DEFAULT";
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
default: return "Unknown";
}
}
inline std::string_view toString(ConvBwdWeightSpecialization spec)
{
using enum ConvBwdWeightSpecialization;
switch(spec)
{
case DEFAULT: return "DEFAULT";
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0";
case ODD_C: return "ODD_C";
default: return "Unknown";
}
}
inline std::string_view toString(GemmPadding padding)
inline std::string_view to_string(GemmPadding padding)
{
using enum GemmPadding;
switch(padding)
@@ -435,7 +397,7 @@ inline std::string_view toString(GemmPadding padding)
}
}
inline std::string_view toString(PipelineScheduler sched)
inline std::string_view to_string(PipelineScheduler sched)
{
using enum PipelineScheduler;
switch(sched)
@@ -447,7 +409,7 @@ inline std::string_view toString(PipelineScheduler sched)
}
}
inline std::string_view toString(TensorLayout layout)
inline std::string_view to_string(TensorLayout layout)
{
using enum TensorLayout;
switch(layout)
@@ -503,63 +465,46 @@ inline std::string_view toString(TensorLayout layout)
}
// ostream operator overloads for enum classes
inline std::ostream& operator<<(std::ostream& os, DataType dt) { return os << toString(dt); }
inline std::ostream& operator<<(std::ostream& os, DataType dt) { return os << to_string(dt); }
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir) { return os << toString(dir); }
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir)
{
return os << to_string(dir);
}
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
{
return os << toString(op);
return os << to_string(op);
}
inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver)
{
return os << toString(ver);
return os << to_string(ver);
}
inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
{
return os << toString(spec);
return os << to_string(spec);
}
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
inline std::ostream& operator<<(std::ostream& os, ConvSpecialization spec)
{
return os << toString(spec);
}
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
{
return os << toString(spec);
}
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
{
return os << toString(spec);
return os << to_string(spec);
}
inline std::ostream& operator<<(std::ostream& os, GemmPadding padding)
{
return os << toString(padding);
return os << to_string(padding);
}
inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
{
return os << toString(sched);
return os << to_string(sched);
}
inline std::ostream& operator<<(std::ostream& os, TensorLayout layout)
{
return os << toString(layout);
}
// ostream operator overload for std::variant of convolution specializations
inline std::ostream& operator<<(std::ostream& os,
const std::variant<ConvFwdSpecialization,
ConvBwdDataSpecialization,
ConvBwdWeightSpecialization>& spec)
{
std::visit([&os](const auto& s) { os << s; }, spec);
return os;
return os << to_string(layout);
}
} // namespace ck_tile::builder

View File

@@ -83,11 +83,14 @@ add_ck_builder_test(test_ckb_conv_builder
unit_tensor_foreach.cpp
unit_error.cpp
unit_validation.cpp
unit_debug.cpp
unit_conv_fwd_testing.cpp
unit_conv_elementwise_op.cpp
unit_conv_tensor_layout.cpp
unit_conv_tensor_type.cpp
unit_conv_thread_block.cpp
unit_conv_tuning_params.cpp)
target_link_libraries(test_ckb_conv_builder PRIVATE utility)
# Tests the inline diff utility used for comparing strings in test assertions
add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
@@ -121,7 +124,7 @@ add_ck_builder_test(test_ckb_conv_description
# Verifies that GetInstanceString() methods and other functions produce valid kernel code.
# Tests various convolution types:
# - Group convolution (v3, standard, large tensor, WMMA, DL variants)
# - Backward weight group convolution (XDL)
# - Backward weight group convolution (XDL standard, XDL v3, WMMA, DL, multiple D, two-stage variants)
# Requires kernel compilation to validate the generated strings through the base class.
set(INSTANCE_STRING_TESTS
@@ -164,10 +167,35 @@ add_ck_builder_test(test_ckb_build_fwd_instances
conv/ck/test_ckb_conv_fwd_3d_fp16.cpp
conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp)
)
target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
set(BWD_WEIGHT_TESTS
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp
conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp
conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp
conv/ck/test_ckb_conv_bwd_weight_dl.cpp
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
)
if (CK_USE_WMMA)
list(APPEND BWD_WEIGHT_TESTS
conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp
conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp
conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp
conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp
)
endif()
add_ck_builder_test(test_ckb_build_bwd_weight_instances ${BWD_WEIGHT_TESTS})
target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility)
add_ck_builder_test(test_ckb_build_bwd_data_instances
conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp
)
target_link_libraries(test_ckb_build_bwd_data_instances PRIVATE utility)
################################################################################
# FACTORY TESTS - Expensive Regression Tests (Full MIOpen Kernel Set)
@@ -221,6 +249,8 @@ endforeach()
set(CKB_REGRESSION_TESTS
test_ckb_instance_string
test_ckb_build_fwd_instances
test_ckb_build_bwd_weight_instances
test_ckb_build_bwd_data_instances
test_ckb_testing_utils
# test_ckb_factory_grouped_convolution_forward_convscale
# test_ckb_factory_grouped_convolution_forward_scaleadd_ab

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
// Convolution signature under test: 2-D backward-weight convolution on BF16 data with
// FP32 accumulation, using GNHWC / GKYXC / GNHWK layouts for input / weight / output.
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::BF16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
// Algorithm description for the DL backward-weight device operation, assembled from the
// presets exposed by the test utilities (cku). Together with SIGNATURE it parameterizes
// ConvBuilder below.
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl{}
                               .with_thread_block(cku::ThreadBlock_256_128x128x16)
                               // ConvSpecialization is declared in ck_tile::builder, so it is
                               // named through the ckb alias rather than through cku, which
                               // would only resolve via a using-directive in test_utils.
                               .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
                               .with_dl_thread_config(cku::DlThreadConfig_16x1x4x4x1)
                               .with_dl_thread_cluster(cku::DlThreadCluster_8x2)
                               .with_dl_transfer(cku::DlTransfer5D);
// Builder specialized on the signature/algorithm pair above.
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
// Instance type produced by the builder. It is not referenced again in this file;
// NOTE(review): presumably kept so that Builder::Instance gets named/compiled — confirm.
using Instance = Builder::Instance;
// Test for the DL backward-weight configuration: derives the expected parameter string
// from ALGORITHM via to_string(), prints it, and passes it to cku::run_test<Builder>
// together with the expected kernel name, specialization, layouts and elementwise ops.
// NOTE(review): presumably run_test checks that each listed string occurs in the
// generated instance description — confirm against utils/ckb_conv_test_utils.hpp.
TEST(BwdWeight_2DBf16_DL, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Dl",
expected_transfer_parameters,
"Default",
"GNHWC,GKYXC,GNHWK",
"PassThrough,PassThrough,PassThrough"});
}

View File

@@ -0,0 +1,42 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
// Convolution signature under test: 3-D backward-weight convolution on FP16 data with
// FP32 accumulation, using GNDHWC / GKZYXC / GNDHWK layouts for input / weight / output.
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 3,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::FP16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = ckb::TensorLayout::GNDHWC}},
.weight = {.config = {.layout = ckb::TensorLayout::GKZYXC}},
.output = {.config = {.layout = ckb::TensorLayout::GNDHWK}}};
// Algorithm description for the multiple-D WMMA CShuffle V3 backward-weight device
// operation, assembled from presets exposed by the test utilities (cku). Together with
// SIGNATURE it parameterizes ConvBuilder below.
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3{}
.with_thread_block(cku::ThreadBlock_64_32x32x32)
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave);
// Builder specialized on the signature/algorithm pair above.
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
// Instance type produced by the builder. It is not referenced again in this file;
// NOTE(review): presumably kept so that Builder::Instance gets named/compiled — confirm.
using Instance = Builder::Instance;
// Test for the multiple-D WMMA CShuffle V3 backward-weight configuration: derives the
// expected parameter string from ALGORITHM via to_string(), prints it, and passes it to
// cku::run_test<Builder> together with the other expected strings listed below.
// NOTE(review): the suite name ends in "GNHWC", but the expected layout string below is
// the 3-D "GNDHWC,GKZYXC,GNDHWK" — confirm whether the name should say GNDHWC.
// NOTE(review): presumably run_test checks that each listed string occurs in the
// generated instance description — confirm against utils/ckb_conv_test_utils.hpp.
TEST(BwdWeight_3DFp16_MultiD_Wmma_ShuffleV3_GNHWC, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3",
expected_transfer_parameters,
"Default",
"GNDHWC,GKZYXC,GNDHWK",
"PassThrough,PassThrough,PassThrough",
"fp16,fp16>"}); // check compute types
}

View File

@@ -0,0 +1,41 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
// Convolution signature under test: 2-D backward-weight convolution on FP16 data with
// FP32 accumulation, using GNHWC / GKYXC / GNHWK layouts for input / weight / output.
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::FP16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
// Algorithm description for the multiple-D XDL CShuffle backward-weight device operation,
// assembled from presets exposed by the test utilities (cku). Together with SIGNATURE it
// parameterizes ConvBuilder below.
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle{}
.with_thread_block(cku::ThreadBlock_256_128x128x8)
.with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(cku::BwdTransfer_4x64x1)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT);
// Builder specialized on the signature/algorithm pair above.
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
// Instance type produced by the builder. It is not referenced again in this file;
// NOTE(review): presumably kept so that Builder::Instance gets named/compiled — confirm.
using Instance = Builder::Instance;
// Test for the multiple-D XDL CShuffle backward-weight configuration: derives the expected
// parameter string from ALGORITHM via to_string(), prints it, and passes it to
// cku::run_test<Builder> together with the other expected strings listed below.
// NOTE(review): presumably run_test checks that each listed string occurs in the
// generated instance description — confirm against utils/ckb_conv_test_utils.hpp.
TEST(BwdWeight_2DFp16_MultiD_CShuffle_GNHWC, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle",
expected_transfer_parameters,
"Default",
"GNHWC,GKYXC,GNHWK",
"PassThrough,PassThrough,PassThrough",
"fp16,fp16>"}); // check compute types
}

View File

@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using enum ck_tile::builder::TensorLayout;
// Convolution signature under test: 2-D backward-weight convolution on FP16 data with
// FP32 accumulation, using NGCHW / GKYXC / NGKHW layouts for input / weight / output.
// The unqualified layout names come from the file-scope `using enum TensorLayout`.
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::FP16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = NGCHW}},
.weight = {.config = {.layout = GKYXC}},
.output = {.config = {.layout = NGKHW}}};
// Algorithm description for the two-stage WMMA CShuffle V3 backward-weight device
// operation, assembled from presets exposed by the test utilities (cku). It additionally
// sets the number of convolution groups to merge (2) and the transpose parameters (2, 2).
// Together with SIGNATURE it parameterizes ConvBuilder below.
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3{}
.with_thread_block(cku::ThreadBlock_64_32x32x32)
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave)
.with_num_conv_groups_to_merge(2)
.with_transpose_params(2, 2);
// Builder specialized on the signature/algorithm pair above.
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
// Instance type produced by the builder. It is not referenced again in this file;
// NOTE(review): presumably kept so that Builder::Instance gets named/compiled — confirm.
using Instance = Builder::Instance;
// Test for the two-stage WMMA CShuffle V3 backward-weight configuration: derives the
// expected parameter string from ALGORITHM via to_string(), prints it, and passes it to
// cku::run_test<Builder> together with the other expected strings listed below.
// NOTE(review): presumably run_test checks that each listed string occurs in the
// generated instance description — confirm against utils/ckb_conv_test_utils.hpp.
TEST(BwdWeight_2DFp16_TwoStage_Wmma_CShuffle_V3, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3",
expected_transfer_parameters,
"Default",
"NGCHW,GKYXC,NGKHW",
"PassThrough,PassThrough,PassThrough",
"Intrawave",
"v1",
"fp16,fp16,2,2>"}); // Check compute types and transpose params.
}

View File

@@ -0,0 +1,44 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
// Convolution signature under test: 2-D backward-weight convolution on BF16 data with
// FP32 accumulation, using GNHWC / GKYXC / GNHWK layouts for input / weight / output.
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::BF16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
// Algorithm description for the two-stage XDL CShuffle backward-weight device operation,
// assembled from presets exposed by the test utilities (cku). It additionally sets the
// number of convolution groups to merge (2) and the transpose parameters (2, 4).
// Together with SIGNATURE it parameterizes ConvBuilder below.
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle{}
.with_thread_block(cku::ThreadBlock_64_32x32x32)
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave)
.with_num_conv_groups_to_merge(2)
.with_transpose_params(2, 4);
// Builder specialized on the signature/algorithm pair above.
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
// Instance type produced by the builder. It is not referenced again in this file;
// NOTE(review): presumably kept so that Builder::Instance gets named/compiled — confirm.
using Instance = Builder::Instance;
// Test for the two-stage XDL CShuffle backward-weight configuration: derives the expected
// parameter string from ALGORITHM via to_string() and passes it to cku::run_test<Builder>
// together with the other expected strings listed below.
// NOTE(review): presumably run_test checks that each listed string occurs in the
// generated instance description — confirm against utils/ckb_conv_test_utils.hpp.
TEST(BwdWeight_2DBf16_TwoStage_CShuffle, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle",
expected_transfer_parameters,
"Default",
"GNHWC,GKYXC,GNHWK",
"PassThrough,PassThrough,PassThrough",
"Intrawave,v2", // pipeline versions
"bf16,bf16,2,4>"}); // compute types and transpose params
}

View File

@@ -0,0 +1,43 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using enum ck_tile::builder::TensorLayout;
// Convolution signature under test: 3-D backward-weight convolution on BF16 data with
// FP32 accumulation, using NGCDHW / GKZYXC / NGKDHW layouts for input / weight / output.
// The unqualified layout names come from the file-scope `using enum TensorLayout`.
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 3,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::BF16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = NGCDHW}},
.weight = {.config = {.layout = GKZYXC}},
.output = {.config = {.layout = NGKDHW}}};
// Algorithm description for the WMMA CShuffle backward-weight device operation, assembled
// from presets exposed by the test utilities (cku). It additionally sets a prefetch
// configuration (1, default scheduler) and gridwise GEMM pipeline version V1.
// Together with SIGNATURE it parameterizes ConvBuilder below.
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{}
.with_thread_block(cku::ThreadBlock_64_32x32x32)
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
.with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT)
.with_gridwise_gemm_pipeline(ckb::PipelineVersion::V1);
// Builder specialized on the signature/algorithm pair above.
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
// Instance type produced by the builder. It is not referenced again in this file;
// NOTE(review): presumably kept so that Builder::Instance gets named/compiled — confirm.
using Instance = Builder::Instance;
// Test for the WMMA CShuffle backward-weight configuration: derives the expected parameter
// string from ALGORITHM via to_string(), prints it, and passes it to
// cku::run_test<Builder> together with the other expected strings listed below.
// NOTE(review): presumably run_test checks that each listed string occurs in the
// generated instance description — confirm against utils/ckb_conv_test_utils.hpp.
TEST(BwdWeight_3DBf16_Wmma_CShuffle, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Wmma_CShuffle",
expected_transfer_parameters,
"Default",
"NGCDHW,GKZYXC,NGKDHW",
"PassThrough,PassThrough,PassThrough",
"v1"});
}

View File

@@ -0,0 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using enum ck_tile::builder::TensorLayout;
// Convolution signature under test: 1-D backward-weight convolution on BF16 data with
// FP32 accumulation, using NGCW / GKXC / NGKW layouts for input / weight / output.
// The unqualified layout names come from the file-scope `using enum TensorLayout`.
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::BF16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = NGCW}},
.weight = {.config = {.layout = GKXC}},
.output = {.config = {.layout = NGKW}}};
// Algorithm description for the WMMA CShuffle V3 backward-weight device operation,
// assembled from presets exposed by the test utilities (cku). It selects the
// FILTER_1X1_STRIDE1_PAD0 convolution specialization and transpose parameters (4, 4).
// Together with SIGNATURE it parameterizes ConvBuilder below.
constexpr auto ALGORITHM =
cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3{}
.with_thread_block(cku::ThreadBlock_64_32x32x32)
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave)
.with_transpose_params(4, 4);
// Builder specialized on the signature/algorithm pair above.
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
// Instance type produced by the builder. It is not referenced again in this file;
// NOTE(review): presumably kept so that Builder::Instance gets named/compiled — confirm.
using Instance = Builder::Instance;
// Test for the WMMA CShuffle V3 backward-weight configuration: derives the expected
// parameter string from ALGORITHM via to_string(), prints it, and passes it to
// cku::run_test<Builder> together with the other expected strings listed below.
// NOTE(review): presumably run_test checks that each listed string occurs in the
// generated instance description — confirm against utils/ckb_conv_test_utils.hpp.
TEST(BwdWeight_1DBf16_Wmma_CShuffle_V3, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Wmma_CShuffleV3",
expected_transfer_parameters,
"Filter1x1Stride1Pad0",
"NGCW,GKXC,NGKW",
"PassThrough,PassThrough,PassThrough",
"Intrawave",
"v1",
"bf16,bf16,4,4>"});
}

View File

@@ -0,0 +1,41 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
// Convolution signature under test: 2-D backward-weight convolution on FP16 data with
// FP32 accumulation, using GNHWC / GKYXC / GNHWK layouts for input / weight / output.
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::FP16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
// Algorithm description for the XDL CShuffle backward-weight device operation, assembled
// from presets exposed by the test utilities (cku). It additionally sets transpose
// parameters (2, 2). Together with SIGNATURE it parameterizes ConvBuilder below.
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle{}
.with_thread_block(cku::ThreadBlock_256_128x128x8)
.with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(cku::BwdTransfer_4x64x1)
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
.with_transpose_params(2, 2);
// Builder specialized on the signature/algorithm pair above.
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
// Instance type produced by the builder. It is not referenced again in this file;
// NOTE(review): presumably kept so that Builder::Instance gets named/compiled — confirm.
using Instance = Builder::Instance;
// Test for the XDL CShuffle backward-weight configuration: derives the expected parameter
// string from ALGORITHM via to_string() and passes it to cku::run_test<Builder> together
// with the other expected strings listed below.
// NOTE(review): presumably run_test checks that each listed string occurs in the
// generated instance description — confirm against utils/ckb_conv_test_utils.hpp.
TEST(BwdWeight_2DFp16_CShuffle_GNHWC, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffle",
expected_transfer_parameters,
"Default",
"GNHWC,GKYXC,GNHWK",
"PassThrough,PassThrough,PassThrough",
"fp16,fp16,2,2>"}); // check compute types and transpose params
}

View File

@@ -0,0 +1,43 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using enum ck_tile::builder::TensorLayout;
// Convolution signature under test: 1-D backward-weight convolution on BF16 data with
// FP32 accumulation, using NGCW / GKXC / NGKW layouts for input / weight / output.
// The unqualified layout names come from the file-scope `using enum TensorLayout`.
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::BF16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = NGCW}},
.weight = {.config = {.layout = GKXC}},
.output = {.config = {.layout = NGKW}}};
// Algorithm description for the XDL CShuffle V3 backward-weight device operation,
// assembled from presets exposed by the test utilities (cku). It selects the
// FILTER_1X1_STRIDE1_PAD0 convolution specialization.
// Together with SIGNATURE it parameterizes ConvBuilder below.
constexpr auto ALGORITHM =
cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{}
.with_thread_block(cku::ThreadBlock_64_32x32x32)
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave);
// Builder specialized on the signature/algorithm pair above.
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
// Instance type produced by the builder. It is not referenced again in this file;
// NOTE(review): presumably kept so that Builder::Instance gets named/compiled — confirm.
using Instance = Builder::Instance;
// Test for the XDL CShuffle V3 backward-weight configuration: derives the expected
// parameter string from ALGORITHM via to_string() and passes it to cku::run_test<Builder>
// together with the other expected strings listed below.
// NOTE(review): presumably run_test checks that each listed string occurs in the
// generated instance description — confirm against utils/ckb_conv_test_utils.hpp.
TEST(BwdWeight_1DBf16_CShuffle_V3, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3",
expected_transfer_parameters,
"Filter1x1Stride1Pad0",
"NGCW,GKXC,NGKW",
"PassThrough,PassThrough,PassThrough",
"Intrawave",
"v2"});
}

View File

@@ -30,11 +30,11 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_256x256x32)
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
GemmSpecialization::MNKPadding)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v2_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -27,11 +27,12 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(FwdThreadBlock_64_64x32x32)
.with_thread_block(ThreadBlock_64_64x32x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x16x1)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 2, PipelineScheduler::DEFAULT);
.with_transfer(Transfer_4x16x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
.with_num_conv_groups_to_merge(2);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -22,18 +22,20 @@ TEST(FwdConvInstances,
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
.direction = FORWARD,
.data_type = I8,
.accumulation_data_type = INT32,
.accumulation_data_type = I32,
.input = {.config = {.layout = GNWC}},
.weight = {.config = {.layout = GKXC}},
.output = {.config = {.layout = GNWK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{}
.with_thread_block(FwdThreadBlock_128_64x64x64)
.with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave)
.with_transfer(FwdTransfer_4x32x1)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 0, PipelineScheduler::DEFAULT);
.with_thread_block(ThreadBlock_128_64x64x64)
.with_gemm_config(GemmParams_Wmma_2x1_per_wave)
.with_transfer(Transfer_4x32x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
.with_num_conv_groups_to_merge(2)
.with_gridwise_gemm_pipeline(PipelineVersion::V1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -27,10 +27,10 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_256x256x32)
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v1_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
@@ -64,10 +64,11 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_256x256x32)
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::FILTER_3x3,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v5_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -32,11 +32,12 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(FwdThreadBlock_64_64x32x32)
.with_thread_block(ThreadBlock_64_64x32x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x16x1)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
.with_transfer(Transfer_4x16x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
.with_num_conv_groups_to_merge(1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -25,15 +25,16 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
.with_thread_block(FwdThreadBlock_256_128x128x16)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_thread_block(ThreadBlock_256_128x128x16)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
.with_dl_thread_cluster(DlThreadCluster_8x2)
.with_dl_transfer(DlFwdTransfer);
.with_dl_transfer(DlTransfer4D);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
run_test<Builder>({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK",
expected_transfer_parameters,
"Default",
@@ -59,16 +60,17 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
.with_thread_block(FwdThreadBlock_256_128x128x16)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_thread_block(ThreadBlock_256_128x128x16)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
.with_dl_thread_cluster(DlThreadCluster_8x2)
.with_dl_transfer(DlFwdTransfer);
.with_dl_transfer(DlTransfer4D);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
run_test<Builder>({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK",
expected_transfer_parameters,
"Filter1x1Pad0",

View File

@@ -25,11 +25,11 @@ constexpr auto SIGNATURE =
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(cku::FwdThreadBlock_256_256x256x32)
.with_thread_block(cku::ThreadBlock_256_256x256x32)
.with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(cku::FwdTransfer_4x64x1)
.with_specializations(ckb::ConvFwdSpecialization::DEFAULT,
ckb::GemmSpecialization::MNKPadding)
.with_transfer(cku::Transfer_4x64x1)
.with_fwd_specializations(ckb::ConvSpecialization::DEFAULT,
ckb::GemmSpecialization::MNKPadding)
.with_block_gemm(cku::BlockGemmDesc_v3_intrawave);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;

View File

@@ -26,11 +26,11 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_128x128x32)
.with_thread_block(ThreadBlock_256_128x128x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
GemmSpecialization::MNKPadding)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v4_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -27,11 +27,12 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(FwdThreadBlock_256_256x128x32)
.with_thread_block(ThreadBlock_256_256x128x32)
.with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave)
.with_transfer(FwdTransfer_4x64x1_fp8)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
.with_transfer(Transfer_4x64x1_fp8)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
.with_num_conv_groups_to_merge(1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -25,14 +25,13 @@ TEST(FwdConvInstances,
.output = {.config = {.layout = GNHWK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
.base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(FwdThreadBlock_256_256x128x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x16x1)
.with_specializations(ConvFwdSpecialization::DEFAULT,
GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{}
.with_thread_block(ThreadBlock_256_256x128x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(Transfer_4x16x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
.with_num_conv_groups_to_merge(1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
@@ -62,14 +61,14 @@ TEST(
.output = {.config = {.layout = GNHWK}}};
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
.base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
.with_thread_block(FwdThreadBlock_128_128x128x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x16x1)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{}
.with_thread_block(ThreadBlock_128_128x128x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(Transfer_4x16x1)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
.with_num_conv_groups_to_merge(1);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -27,10 +27,10 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_256x256x32)
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v3_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -27,11 +27,11 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_128x128x32)
.with_thread_block(ThreadBlock_256_128x128x32)
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v4_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -27,11 +27,11 @@ TEST(FwdConvInstances,
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
.with_thread_block(FwdThreadBlock_256_256x256x32)
.with_thread_block(ThreadBlock_256_256x256x32)
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
.with_transfer(FwdTransfer_4x64x1)
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_transfer(Transfer_4x64x1)
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
GemmSpecialization::MNKPadding)
.with_block_gemm(BlockGemmDesc_v1_intrawave);
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;

View File

@@ -101,7 +101,7 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
@@ -229,7 +229,7 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
@@ -313,7 +313,7 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);

View File

@@ -230,7 +230,7 @@ TEST(InstanceToConvTraits, ExtractsDefaultSpecialization)
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
}
TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization)
@@ -289,8 +289,7 @@ TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization)
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
EXPECT_EQ(Traits::conv_specialization,
ck_tile::builder::ConvFwdSpecialization::FILTER_1X1_PAD0);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0);
}
// ============================================================================

View File

@@ -8,26 +8,27 @@ namespace {
using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::BACKWARD_DATA,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};
constexpr ConvSignature BwdDataConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::BACKWARD_DATA,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};
constexpr auto FwdConvAlgorithm =
constexpr auto BwdDataConvAlgorithm =
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
.with_tile_specializations(TileConvSpecialization::DEFAULT)
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
.with_tile_thread_block(TileThreadBlock_64x64x64)
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(FwdTileTransfer_4x4x4)
.with_tile_transfer(TileTransfer_4x4x4)
.with_tile_optimizations(TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
using Builder = ConvBuilder<BwdDataConvSignature, BwdDataConvAlgorithm>;
run_ck_tile_test<Builder>({
"grouped_convolution_backward_data",
"fp16",

View File

@@ -8,26 +8,27 @@ namespace {
using namespace ck_tile::builder::test_utils;
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
{
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
.direction = ConvDirection::BACKWARD_WEIGHT,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};
constexpr ConvSignature BwdWeightConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::BACKWARD_WEIGHT,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};
constexpr auto FwdConvAlgorithm =
constexpr auto BwdWeightConvAlgorithm =
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
.with_tile_specializations(TileConvSpecialization::DEFAULT)
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
.with_tile_thread_block(TileThreadBlock_64x64x64)
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(FwdTileTransfer_4x4x4)
.with_tile_transfer(TileTransfer_4x4x4)
.with_tile_optimizations(TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
using Builder = ConvBuilder<BwdWeightConvSignature, BwdWeightConvAlgorithm>;
run_ck_tile_test<Builder>({
"grouped_convolution_backward_weight",
"fp16",

View File

@@ -21,9 +21,9 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1
constexpr auto FwdConvAlgorithm =
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
.with_tile_specializations(TileConvSpecialization::DEFAULT)
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
.with_tile_thread_block(TileThreadBlock_64x64x64)
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(FwdTileTransfer_4x4x4)
.with_tile_transfer(TileTransfer_4x4x4)
.with_tile_optimizations(TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});

View File

@@ -28,18 +28,31 @@ struct ThreadBlock
};
static_assert(ckb::ThreadBlockDescriptor<ThreadBlock>);
// Describe gridwise XDL GEMM parameters.
struct GridwiseXdlGemm
struct XdlParams
{
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
size_t ak1 = 0;
size_t bk1 = 0;
size_t m_per_xdl = 0;
size_t n_per_xdl = 0;
size_t m_xdl_per_wave = 0;
size_t n_xdl_per_wave = 0;
};
static_assert(ckb::GridwiseXdlGemmDescriptor<GridwiseXdlGemm>);
static_assert(ckb::GridwiseXdlGemmDescriptor<XdlParams>);
// Describe gridwise XDL GEMM parameters.
struct GridwiseFwdXdlGemm
{
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
size_t ak1 = 0;
size_t bk1 = 0;
XdlParams xdl_params;
};
static_assert(ckb::GridwiseFwdXdlGemmDescriptor<GridwiseFwdXdlGemm>);
struct GridwiseBwdXdlGemm
{
size_t k1 = 0;
XdlParams xdl_params;
};
static_assert(ckb::GridwiseBwdXdlGemmDescriptor<GridwiseBwdXdlGemm>);
// Describe gridwise WMMA GEMM parameters.
struct GridwiseWmmaGemm
@@ -49,25 +62,36 @@ struct GridwiseWmmaGemm
size_t n_per_wmma = 0;
size_t m_wmma_per_wave = 0;
size_t n_wmma_per_wave = 0;
PipelineVersion pipeline_version;
};
static_assert(ckb::GridwiseWmmaGemmDescriptor<GridwiseWmmaGemm>);
struct BlockGemm
struct BlockGemmPipeline
{
PipelineVersion pipeline_version;
PipelineScheduler scheduler;
};
static_assert(ckb::BlockGemmDescriptor<BlockGemm>);
static_assert(ckb::BlockGemmPipelineDescriptor<BlockGemmPipeline>);
// Describe A and B block transfer thread cluster lengths.
template <size_t ThreadSliceLength = 3>
struct BlockTransfer
{
size_t k0;
size_t m_n;
size_t k1;
size_t k_batch_size;
};
static_assert(ckb::BlockTransferDescriptor<BlockTransfer>);
// Specialization for ThreadSliceLength == 3
template <>
struct BlockTransfer<3>
{
size_t k0;
size_t m_n;
size_t k1;
};
static_assert(ckb::BlockTransferDescriptor<BlockTransfer<3>, 3>);
static_assert(ckb::BlockTransferDescriptor<BlockTransfer<4>, 4>);
// Describe C block transfer thread cluster lengths.
struct ThreadCluster
@@ -97,31 +121,35 @@ struct Epilogue
};
static_assert(EpilogueDescriptor<Epilogue>);
template <size_t ThreadSliceLength = 3>
struct AccessOrder
{
std::array<size_t, 3> order;
std::array<size_t, ThreadSliceLength> order;
};
static_assert(AccessOrderDescriptor<AccessOrder>);
static_assert(AccessOrderDescriptor<AccessOrder<>>);
static_assert(AccessOrderDescriptor<AccessOrder<4>>);
struct TransferAB
template <size_t ThreadSliceLength = 3>
struct InputTransfer
{
BlockTransfer block_transfer;
BlockTransfer<ThreadSliceLength> block_transfer;
LdsTransfer lds_transfer;
AccessOrder block_transfer_access_order;
AccessOrder src_access_order;
AccessOrder<ThreadSliceLength> block_transfer_access_order;
AccessOrder<ThreadSliceLength> src_access_order;
};
struct TransferC
struct OutputTransfer
{
ThreadCluster thread_cluster_dims;
Epilogue epilogue;
};
struct TransferABC
template <size_t ThreadSliceLength = 3>
struct Transfer
{
TransferAB a;
TransferAB b;
TransferC c;
InputTransfer<ThreadSliceLength> a;
InputTransfer<ThreadSliceLength> b;
OutputTransfer c;
};
// DL-specific descriptors
@@ -142,17 +170,19 @@ struct DlThreadCluster
};
static_assert(ckb::DlThreadClusterDescriptor<DlThreadCluster>);
template <size_t D = 4>
struct DlBlockTransfer
{
std::array<size_t, 4> thread_slice_lengths;
std::array<size_t, 4> thread_cluster_lengths;
std::array<size_t, 4> thread_cluster_arrange_order;
std::array<size_t, 4> src_access_order;
std::array<size_t, 4> src_vector_tensor_lengths;
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
std::array<size_t, 4> dst_vector_tensor_lengths;
std::array<size_t, D> thread_slice_lengths;
std::array<size_t, D> thread_cluster_lengths;
std::array<size_t, D> thread_cluster_arrange_order;
std::array<size_t, D> src_access_order;
std::array<size_t, D> src_vector_tensor_lengths;
std::array<size_t, D> src_vector_tensor_contiguous_dim_order;
std::array<size_t, D> dst_vector_tensor_lengths;
};
static_assert(ckb::DlBlockTransferDescriptor<DlBlockTransfer>);
static_assert(ckb::DlBlockTransferDescriptor4D<DlBlockTransfer<4>>);
static_assert(ckb::DlBlockTransferDescriptor5D<DlBlockTransfer<5>>);
struct DlEpilogue
{
@@ -169,9 +199,14 @@ struct ThreadBlock_
ThreadBlock thread_block;
};
struct XdlGemm_
struct FwdXdlGemm_
{
GridwiseXdlGemm gridwise_gemm;
GridwiseFwdXdlGemm gridwise_gemm;
};
struct BwdXdlGemm_
{
GridwiseBwdXdlGemm gridwise_gemm;
};
struct WmmaGemm_
@@ -179,27 +214,48 @@ struct WmmaGemm_
GridwiseWmmaGemm gridwise_gemm;
};
template <size_t ThreadSliceLength = 3>
struct Transfer_
{
TransferABC transfer;
Transfer<ThreadSliceLength> transfer;
};
struct ConvSpecialization_
struct ConvSpecializationFwd_
{
ConvFwdSpecialization fwd_specialization;
ConvSpecialization fwd_specialization;
GemmSpecialization gemm_specialization;
};
struct ConvSpecializationBwdWeight_
{
ConvSpecialization bwd_weight_specialization;
};
struct Prefetch_
{
size_t num_gemm_k_prefetch_stages;
size_t num_groups_to_merge;
PipelineScheduler loop_scheduler;
};
// Algorithm component carrying the two "max transpose transfer scalar per vector"
// limits (source and destination). Both default to 1; with_transpose_params()
// on ConvAlgorithmTemplate overwrites them.
struct TransposeParams_
{
size_t max_transpose_transfer_src_scalar_per_vector{1};
size_t max_transpose_transfer_dst_scalar_per_vector{1};
};
// Algorithm component carrying the number of convolution groups to merge.
// Defaults to 1; with_num_conv_groups_to_merge() on ConvAlgorithmTemplate
// overwrites it.
struct GemmBatchOptions_
{
size_t num_conv_groups_to_merge{1};
};
struct BlockGemm_
{
BlockGemm block_gemm;
BlockGemmPipeline block_gemm_pipeline;
};
struct GridGemm_
{
PipelineVersion pipeline_version;
};
struct DlThreadConfig_
@@ -212,33 +268,34 @@ struct DlThreadCluster_
DlThreadCluster thread_cluster;
};
struct DlBlockTransferAB
template <size_t Dim = 4>
struct DlTransfer
{
DlBlockTransfer block_transfer;
};
struct DlBlockTransferC
{
DlEpilogue epilogue;
};
struct DlTransferABC
{
DlBlockTransferAB a;
DlBlockTransferAB b;
DlBlockTransferC c;
DlBlockTransfer<Dim> a;
DlBlockTransfer<Dim> b;
DlEpilogue c;
};
template <size_t Dim = 4>
struct DlTransfer_
{
DlTransferABC transfer;
DlTransfer<Dim> transfer;
};
// Specialization wrapper for large tensor support
template <typename BaseAlgorithm>
struct LargeTensorWrapper
struct TwoStageSpecialization_
{
static constexpr ConvAlgorithmSpecialization specialization =
ConvAlgorithmSpecialization::TWO_STAGE;
};
struct MultipleDSpecialization_
{
static constexpr ConvAlgorithmSpecialization specialization =
ConvAlgorithmSpecialization::MULTIPLE_D;
};
struct LargeTensorSpecialization_
{
BaseAlgorithm base_algorithm;
static constexpr ConvAlgorithmSpecialization specialization =
ConvAlgorithmSpecialization::LARGE_TENSOR;
};
@@ -329,7 +386,11 @@ struct ConvAlgorithmTemplate : Components...
constexpr auto with_gemm_config(const GemmConfig& gemm) const
{
auto result = *this;
if constexpr(std::is_base_of_v<XdlGemm_, ConvAlgorithmTemplate>)
if constexpr(std::is_base_of_v<FwdXdlGemm_, ConvAlgorithmTemplate>)
{
result.gridwise_gemm = gemm;
}
else if constexpr(std::is_base_of_v<BwdXdlGemm_, ConvAlgorithmTemplate>)
{
result.gridwise_gemm = gemm;
}
@@ -337,46 +398,82 @@ struct ConvAlgorithmTemplate : Components...
{
result.gridwise_gemm = gemm;
}
else
{
static_assert(false, "Unrecognized GemmConfig type");
}
return result;
}
template <typename T>
constexpr auto with_transfer(const T& t) const
{
static_assert(std::is_base_of_v<Transfer_, ConvAlgorithmTemplate>);
static_assert(std::is_base_of_v<Transfer_<3>, ConvAlgorithmTemplate> ||
std::is_base_of_v<Transfer_<4>, ConvAlgorithmTemplate>);
auto result = *this;
result.transfer = t;
return result;
}
constexpr auto with_specializations(ConvFwdSpecialization fwd_spec,
GemmSpecialization gemm_spec) const
constexpr auto with_fwd_specializations(ConvSpecialization fwd_spec,
GemmSpecialization gemm_spec) const
{
static_assert(std::is_base_of_v<ConvSpecialization_, ConvAlgorithmTemplate>);
static_assert(std::is_base_of_v<ConvSpecializationFwd_, ConvAlgorithmTemplate>);
auto result = *this;
result.fwd_specialization = fwd_spec;
result.gemm_specialization = gemm_spec;
return result;
}
constexpr auto with_prefetch_config(size_t k_prefetch_stages,
size_t groups_to_merge,
PipelineScheduler scheduler) const
// Returns a copy of this algorithm with bwd_weight_specialization set to bwd_spec.
// Only compiles when ConvSpecializationBwdWeight_ is one of the template's
// component base classes.
constexpr auto with_bwd_specialization(ConvSpecialization bwd_spec) const
{
static_assert(std::is_base_of_v<ConvSpecializationBwdWeight_, ConvAlgorithmTemplate>);
// Work on a copy so the original (const) algorithm object is left unchanged.
auto result = *this;
result.bwd_weight_specialization = bwd_spec;
return result;
}
constexpr auto with_prefetch_config(size_t k_prefetch_stages, PipelineScheduler scheduler) const
{
static_assert(std::is_base_of_v<Prefetch_, ConvAlgorithmTemplate>);
auto result = *this;
result.num_gemm_k_prefetch_stages = k_prefetch_stages;
result.num_groups_to_merge = groups_to_merge;
result.loop_scheduler = scheduler;
return result;
}
// Returns a copy of this algorithm with both transpose-transfer limits replaced:
// max_src_scalar_per_vector -> max_transpose_transfer_src_scalar_per_vector
// max_dst_scalar_per_vector -> max_transpose_transfer_dst_scalar_per_vector
// Only compiles when TransposeParams_ is one of the template's component base classes.
constexpr auto with_transpose_params(size_t max_src_scalar_per_vector,
size_t max_dst_scalar_per_vector) const
{
static_assert(std::is_base_of_v<TransposeParams_, ConvAlgorithmTemplate>);
// Work on a copy so the original (const) algorithm object is left unchanged.
auto result = *this;
result.max_transpose_transfer_src_scalar_per_vector = max_src_scalar_per_vector;
result.max_transpose_transfer_dst_scalar_per_vector = max_dst_scalar_per_vector;
return result;
}
// Returns a copy of this algorithm with num_conv_groups_to_merge set to
// num_groups_to_merge. Only compiles when GemmBatchOptions_ is one of the
// template's component base classes.
constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const
{
static_assert(std::is_base_of_v<GemmBatchOptions_, ConvAlgorithmTemplate>);
// Work on a copy so the original (const) algorithm object is left unchanged.
auto result = *this;
result.num_conv_groups_to_merge = num_groups_to_merge;
return result;
}
template <typename BG>
constexpr auto with_block_gemm(const BG& bg) const
{
static_assert(std::is_base_of_v<BlockGemm_, ConvAlgorithmTemplate>);
auto result = *this;
result.block_gemm = bg;
auto result = *this;
result.block_gemm_pipeline = bg;
return result;
}
// Returns a copy of this algorithm with pipeline_version set to plv.
// Only compiles when GridGemm_ is one of the template's component base classes.
constexpr auto with_gridwise_gemm_pipeline(const PipelineVersion plv) const
{
static_assert(std::is_base_of_v<GridGemm_, ConvAlgorithmTemplate>);
// Work on a copy so the original (const) algorithm object is left unchanged.
auto result = *this;
result.pipeline_version = plv;
return result;
}
@@ -401,7 +498,8 @@ struct ConvAlgorithmTemplate : Components...
template <typename T>
constexpr auto with_dl_transfer(const T& t) const
{
static_assert(std::is_base_of_v<DlTransfer_, ConvAlgorithmTemplate>);
static_assert(std::is_base_of_v<DlTransfer_<4>, ConvAlgorithmTemplate> ||
std::is_base_of_v<DlTransfer_<5>, ConvAlgorithmTemplate>);
auto result = *this;
result.transfer = t;
return result;
@@ -453,26 +551,49 @@ struct ConvAlgorithmTemplate : Components...
}
};
// Algorithm types
// Fwd algorithm types
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, Transfer_, ConvSpecialization_, Prefetch_>;
ConvAlgorithmTemplate<ThreadBlock_,
FwdXdlGemm_,
Transfer_<>,
ConvSpecializationFwd_,
Prefetch_,
GemmBatchOptions_>;
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, Transfer_, ConvSpecialization_, BlockGemm_>;
ConvAlgorithmTemplate<ThreadBlock_,
FwdXdlGemm_,
Transfer_<>,
ConvSpecializationFwd_,
BlockGemm_>;
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_, ConvSpecialization_, Prefetch_>;
ConvAlgorithmTemplate<ThreadBlock_,
WmmaGemm_,
Transfer_<>,
ConvSpecializationFwd_,
GridGemm_,
Prefetch_,
GemmBatchOptions_>;
using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
ConvAlgorithmTemplate<ThreadBlock_,
ConvSpecialization_,
ConvSpecializationFwd_,
DlThreadConfig_,
DlThreadCluster_,
DlTransfer_>;
DlTransfer_<>>;
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
LargeTensorWrapper<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>;
ConvAlgorithmTemplate<ThreadBlock_,
FwdXdlGemm_,
Transfer_<>,
ConvSpecializationFwd_,
Prefetch_,
GemmBatchOptions_,
LargeTensorSpecialization_>;
// CK Tile algorithm
using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileThreadBlock_,
TileBlockGemm_,
TileTransfer_,
@@ -488,4 +609,77 @@ struct ConvAlgorithm_Reference
// GPU reference uses simple algorithm, no tile configuration needed
};
// Bwd weight algorithm types
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_,
BwdXdlGemm_,
Transfer_<4>,
ConvSpecializationBwdWeight_,
TransposeParams_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_,
BwdXdlGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
BlockGemm_,
TransposeParams_,
GemmBatchOptions_,
TwoStageSpecialization_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_,
BwdXdlGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
BlockGemm_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl =
ConvAlgorithmTemplate<ThreadBlock_,
DlThreadConfig_,
DlThreadCluster_,
DlTransfer_<5>,
ConvSpecializationBwdWeight_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_,
BwdXdlGemm_,
Transfer_<4>,
ConvSpecializationBwdWeight_,
MultipleDSpecialization_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_,
WmmaGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
BlockGemm_,
TransposeParams_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_,
WmmaGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
BlockGemm_,
TransposeParams_,
GemmBatchOptions_,
TwoStageSpecialization_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
ConvAlgorithmTemplate<ThreadBlock_,
WmmaGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
GridGemm_,
Prefetch_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 =
ConvAlgorithmTemplate<ThreadBlock_,
WmmaGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
BlockGemm_,
MultipleDSpecialization_>;
} // namespace ck_tile::builder::test

View File

@@ -120,14 +120,12 @@ struct DefaultAlgorithm
ckb::test::ThreadBlock thread_block{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8,
.bk1 = 8,
.m_per_xdl = 16,
.n_per_xdl = 16,
.m_xdl_per_wave = 8,
.n_xdl_per_wave = 8};
ckb::test::GridwiseFwdXdlGemm gridwise_gemm{
.ak1 = 8,
.bk1 = 8,
.xdl_params = {.m_per_xdl = 16, .n_per_xdl = 16, .m_xdl_per_wave = 8, .n_xdl_per_wave = 8}};
ckb::test::TransferABC transfer{
ckb::test::Transfer<> transfer{
.a =
{
.block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2},
@@ -161,10 +159,11 @@ struct DefaultAlgorithm
},
};
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT;
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4,
.scheduler = ckb::PipelineScheduler::INTRAWAVE};
ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT;
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
ckb::test::BlockGemmPipeline block_gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4,
.scheduler =
ckb::PipelineScheduler::INTRAWAVE};
};
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);

View File

@@ -4,6 +4,7 @@
#include "impl/conv_signature_types.hpp"
#include "testing_utils.hpp"
#include "ck_tile/builder/testing/conv_fwd.hpp"
#include "ck_tile/builder/testing/tensor_foreach.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <vector>
@@ -12,6 +13,7 @@ namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
using ::testing::ElementsAreArray;
using ::testing::Eq;
using ::testing::NotNull;
constexpr auto SIGNATURE =
@@ -57,6 +59,8 @@ using UniqueOutputs = ckt::UniqueOutputs<SIGNATURE>;
static_assert(ckt::ValidUniqueInputs<SIGNATURE>);
static_assert(ckt::ValidUniqueOutputs<SIGNATURE>);
static_assert(ckt::TensorReflectable<Inputs, SIGNATURE>);
static_assert(ckt::TensorReflectable<Outputs, SIGNATURE>);
TEST(ConvFwdTesting, MakeDescriptors)
{
@@ -81,3 +85,41 @@ TEST(ConvFwdTesting, Alloc)
EXPECT_THAT(inputs.get().weight, NotNull());
EXPECT_THAT(outputs.get().output, NotNull());
}
// Exercises ckt::validate() on two output sets: it must report no errors when
// both sets were prepared with the same value, and exactly one error per
// reflected output field when every field was prepared with different values.
TEST(ConvFwdTesting, Validate)
{
// Two separate output sets obtained from alloc_outputs(ARGS), compared below.
auto a = alloc_outputs(ARGS);
auto b = alloc_outputs(ARGS);
// Positive test: every output field of `a` and `b` gets the same value (123).
{
// reflect() invokes the callback once per output field, handing over the
// field's descriptor and a pointer-to-member selecting its buffer pointer.
ckt::Outputs<SIGNATURE>::reflect(
ARGS,
[&]([[maybe_unused]] std::string_view name,
const auto& desc,
void* ckt::Outputs<SIGNATURE>::*ptr) {
ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{123});
ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{123});
});
const auto report = ckt::validate(ARGS, a.get(), b.get());
// Identically prepared buffers must not produce any error entries.
EXPECT_THAT(report.get_errors().size(), Eq(0));
}
// Negative test: every output field gets a different value in `a` (2) and `b` (1).
{
// Counts callback invocations, i.e. the number of reflected output fields.
size_t field_count = 0;
ckt::Outputs<SIGNATURE>::reflect(
ARGS,
[&]([[maybe_unused]] std::string_view name,
const auto& desc,
void* ckt::Outputs<SIGNATURE>::*ptr) {
++field_count;
ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{2});
ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{1});
});
const auto report = ckt::validate(ARGS, a.get(), b.get());
// Every field differs, so the report is expected to hold one error per field.
EXPECT_THAT(report.get_errors().size(), Eq(field_count));
}
}

View File

@@ -38,11 +38,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK)
.weight = {.config = {.layout = GKXC}},
.output = {.config = {.layout = NWGK}}};
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 1>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -57,11 +57,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW)
.weight = {.config = {.layout = GKXC}},
.output = {.config = {.layout = NGKW}}};
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 1>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -76,11 +76,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK)
.weight = {.config = {.layout = GKXC}},
.output = {.config = {.layout = GNWK}}};
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 1>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -95,11 +95,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW)
.weight = {.config = {.layout = GKCX}},
.output = {.config = {.layout = NGKW}}};
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 1>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -114,11 +114,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW)
.weight = {.config = {.layout = GKYXC}},
.output = {.config = {.layout = NGKHW}}};
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -133,11 +133,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK)
.weight = {.config = {.layout = GKYXC}},
.output = {.config = {.layout = NHWGK}}};
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -152,11 +152,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK)
.weight = {.config = {.layout = GKYXC}},
.output = {.config = {.layout = GNHWK}}};
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNHWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -171,11 +171,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW)
.weight = {.config = {.layout = GKCYX}},
.output = {.config = {.layout = NGKHW}}};
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCYX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCYX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -190,11 +190,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW)
.weight = {.config = {.layout = GKCZYX}},
.output = {.config = {.layout = NGKDHW}}};
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 3>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCDHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCZYX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKDHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCDHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCZYX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKDHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -209,11 +209,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK)
.weight = {.config = {.layout = GKZYXC}},
.output = {.config = {.layout = NDHWGK}}};
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 3>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NDHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NDHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -228,11 +228,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK)
.weight = {.config = {.layout = GKZYXC}},
.output = {.config = {.layout = GNDHWK}}};
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 3>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNDHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNDHWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNDHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNDHWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -273,7 +273,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout)
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
MockAuxiliaryTensorConfig{.layout = G_K_strided}};
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
EXPECT_EQ(AuxLayouts::Size, 1);
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K>;
@@ -287,7 +287,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout)
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}};
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
EXPECT_EQ(AuxLayouts::Size, 1);
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::GC>;
@@ -301,7 +301,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout)
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
MockAuxiliaryTensorConfig{.layout = G_C_strided}};
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
EXPECT_EQ(AuxLayouts::Size, 1);
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_C>;
@@ -316,7 +316,7 @@ TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors)
MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided},
MockAuxiliaryTensorConfig{.layout = GC}};
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
EXPECT_EQ(AuxLayouts::Size, 2);
using ExpectedType =
@@ -333,7 +333,7 @@ TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors)
MockAuxiliaryTensorConfig{.layout = GC},
MockAuxiliaryTensorConfig{.layout = G_C_strided}};
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
EXPECT_EQ(AuxLayouts::Size, 3);
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K,
@@ -349,7 +349,7 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith1DConvolution)
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
MockAuxiliaryTensorConfig{.layout = G_K_strided}};
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 1, FORWARD>;
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 1>;
EXPECT_EQ(AuxLayouts::Size, 1);
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K>;
@@ -363,7 +363,7 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution)
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
MockAuxiliaryTensorConfig{.layout = GC}};
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 3, FORWARD>;
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 3>;
EXPECT_EQ(AuxLayouts::Size, 1);
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::GC>;
@@ -387,11 +387,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K)
.operation =
OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
@@ -414,11 +414,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC)
.operation =
OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NHWGK>));
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::GC>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
@@ -442,11 +442,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors)
.operation = OutputOp{.elementwise_operation =
ElementwiseOperation::SCALEADD_SCALEADD_RELU}}};
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNHWK>));
using ExpectedDsLayout =
ck::Tuple<ck::tensor_layout::convolution::G_K, ck::tensor_layout::convolution::GC>;
@@ -470,11 +470,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias)
.operation =
OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 1>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NWGK>));
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
@@ -497,11 +497,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias)
.operation = OutputOp{.elementwise_operation =
ElementwiseOperation::BIAS_BNORM_CLAMP}}};
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
using TensorLayouts = ConvTensorLayouts<sig, 3>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NDHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NDHWGK>));
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_C>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));

View File

@@ -27,7 +27,7 @@ TEST(ConvTensorType, Exhaustive)
case FP32: EXPECT_TRUE((check_same<FP32, float>)); break;
case FP16: EXPECT_TRUE((check_same<FP16, ck::half_t>)); break;
case BF16: EXPECT_TRUE((check_same<BF16, ck::bhalf_t>)); break;
case INT32: EXPECT_TRUE((check_same<INT32, uint32_t>)); break;
case I32: EXPECT_TRUE((check_same<I32, uint32_t>)); break;
case FP8: EXPECT_TRUE((check_same<FP8, ck::f8_t>)); break;
case I8: EXPECT_TRUE((check_same<I8, int8_t>)); break;
case U8: EXPECT_TRUE((check_same<U8, uint8_t>)); break;

View File

@@ -19,7 +19,7 @@ TEST(ConvTuningParams, AssignsBlockGemmParams)
{
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3;
ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTRAWAVE;
} block_gemm;
} block_gemm_pipeline;
} kAlgorithm;
constexpr auto block_gemm = SetBlockGemm<kAlgorithm>();
@@ -42,10 +42,7 @@ TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion)
{
constexpr struct Algorithm
{
struct GridwiseGemm
{
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4;
} gridwise_gemm;
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4;
} kAlgorithm;
constexpr auto pipeline_version = SetGridwiseGemmPipelineVersion<kAlgorithm>();
@@ -78,8 +75,8 @@ TEST(ConvTuningParams, AssignsFwdConvSpecialization)
{
constexpr struct Algorithm
{
ckb::ConvFwdSpecialization fwd_specialization =
ckb::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
ckb::ConvSpecialization fwd_specialization =
ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0;
} kAlgorithm;
constexpr auto conv_spec = SetFwdConvSpecialization<kAlgorithm>();

View File

@@ -0,0 +1,464 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "ck_tile/builder/testing/tensor_foreach.hpp"
#include "ck_tile/builder/testing/debug.hpp"
#include "testing_utils.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <sstream>
#include <vector>
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
using ck_tile::test::StringEqWithDiff;
using ::testing::ElementsAreArray;
using ::testing::Eq;
using ::testing::Gt;
// Verifies the textual summary produced by print_descriptor for a packed
// 10x11x12 I32 descriptor, and that printing does not permanently change the
// number formatting of the stream it was given.
TEST(Debug, PrintDescriptor)
{
    auto descriptor =
        ckt::make_descriptor<ckb::DataType::I32>(ckt::Extent{10, 11, 12}, ckt::PackedRightLayout{});

    std::stringstream out;
    ckt::print_descriptor("test", descriptor, out);
    const std::string summary = out.str();

    EXPECT_THAT(summary,
                StringEqWithDiff( //
                    "Descriptor \"test\":\n"
                    " data type: I32\n"
                    " size: 1'320 elements\n"
                    " space: 1'320 elements (5'280 bytes)\n"
                    " lengths: [10, 11, 12]\n"
                    " strides: [132, 12, 1]\n"
                    " packed: yes\n"));

    // The summary uses digit grouping (1'320); a plain integer written to the
    // same stream afterwards must come out ungrouped, i.e. the locale used for
    // printing must not leak into the caller's stream.
    out.str("");
    out << 1000;
    const std::string plain_number = out.str();
    EXPECT_THAT(plain_number, StringEqWithDiff("1000"));
}
// Exercises detail::limited_foreach(total, limit, visit, on_skip):
//  - when total exceeds limit, only indices from the front and the back are
//    visited and on_skip is invoked exactly once with (total - limit);
//  - when total fits within limit, every index is visited in order and
//    on_skip is never invoked.
TEST(Debug, LimitedForeach)
{
    // Even limit: one index from each end.
    {
        std::vector<size_t> visited;
        size_t skip_calls = 0;

        ckt::detail::limited_foreach(
            10,
            2,
            [&](auto index) { visited.push_back(index); },
            [&](auto skipped) {
                ++skip_calls;
                EXPECT_THAT(skipped, Eq(10 - 2));
            });

        EXPECT_THAT(visited, ElementsAreArray({0, 9}));
        EXPECT_THAT(skip_calls, Eq(1));
    }

    // Odd limit: the front receives the extra index (5 leading, 4 trailing).
    {
        std::vector<size_t> visited;
        size_t skip_calls = 0;

        ckt::detail::limited_foreach(
            100,
            9,
            [&](auto index) { visited.push_back(index); },
            [&](auto skipped) {
                ++skip_calls;
                EXPECT_THAT(skipped, Eq(100 - 9));
            });

        EXPECT_THAT(visited, ElementsAreArray({0, 1, 2, 3, 4, 96, 97, 98, 99}));
        EXPECT_THAT(skip_calls, Eq(1));
    }

    // Limit larger than the element count: nothing is skipped.
    {
        size_t visits     = 0;
        size_t skip_calls = 0;

        ckt::detail::limited_foreach(
            50,
            100,
            [&](auto index) {
                // Indices must arrive in increasing order without gaps.
                EXPECT_THAT(index, Eq(visits));
                ++visits;
            },
            [&]([[maybe_unused]] auto skipped) { ++skip_calls; });

        EXPECT_THAT(visits, Eq(50));
        EXPECT_THAT(skip_calls, Eq(0));
    }
}
// A rank-0 tensor (empty extent) is printed as a header with an empty shape
// followed by its single value.
TEST(Debug, PrintTensor0D)
{
    auto descriptor =
        ckt::make_descriptor<ckb::DataType::I32>(ckt::Extent{}, ckt::PackedRightLayout{});
    auto buffer = ckt::alloc_tensor_buffer(descriptor);
    ckt::fill_tensor_buffer(
        descriptor, buffer.get(), []([[maybe_unused]] size_t i) { return 123; });

    std::stringstream out;
    ckt::print_tensor("0D", descriptor, buffer.get(), {}, out);
    const std::string printed = out.str();

    EXPECT_THAT(printed,
                StringEqWithDiff( //
                    "Tensor \"0D\": shape = []\n"
                    " 123\n"));
}
// A 44-element 1D tensor printed with the default configuration is elided to
// its leading and trailing values with a "..." placeholder in between.
TEST(Debug, PrintTensor1D)
{
    auto descriptor =
        ckt::make_descriptor<ckb::DataType::I32>(ckt::Extent{44}, ckt::PackedRightLayout{});
    auto buffer = ckt::alloc_tensor_buffer(descriptor);
    ckt::fill_tensor_buffer(descriptor, buffer.get(), [](size_t i) { return i % 7; });

    std::stringstream out;
    ckt::print_tensor("1D", descriptor, buffer.get(), {}, out);
    const std::string printed = out.str();

    // Note: the matrix separator fields are not printed for a 1D tensor, so
    // their size does not influence the expected output.
    EXPECT_THAT(printed,
                StringEqWithDiff( //
                    "Tensor \"1D\": shape = [44]\n"
                    " 0 1 2 3 4 ... 4 5 6 0 1\n"));
}
// Prints a packed 100x110x120x130 I32 tensor whose elements equal their linear
// index, using reduced column/row/slice limits, and compares against the full
// expected text: a header per printed 2D slice, "..." placeholders for elided
// rows and columns, and a "(skipping N slices...)" line between the leading
// and trailing slices.
TEST(Debug, PrintTensor4D)
{
auto desc = ckt::make_descriptor<ckb::DataType::I32>(ckt::Extent{100, 110, 120, 130},
ckt::PackedRightLayout{});
auto a = ckt::alloc_tensor_buffer(desc);
// Element value == linear index, which makes the expected numbers below easy
// to derive from the slice/row/column position.
ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i; });
std::stringstream ss;
ckt::print_tensor("4D",
desc,
a.get(),
{
// Reduce default limits to have smaller output here.
// That also tests that we can configure these (to some
// extent).
.col_limit = 4,
.row_limit = 4,
.slice_limit = 4,
},
ss);
// 100 * 110 = 11'000 slices in total; 4 are printed, so 10'996 are skipped.
EXPECT_THAT(ss.str(),
StringEqWithDiff( //
"Tensor \"4D\": shape = [100, 110, 120, 130]\n"
"Tensor \"4D\", slice [0, 0, :, :]\n"
" 0 1 ... 128 129\n"
" 130 131 ... 258 259\n"
" ... ... ... ... ...\n"
" 15340 15341 ... 15468 15469\n"
" 15470 15471 ... 15598 15599\n"
"\n"
"Tensor \"4D\", slice [0, 1, :, :]\n"
" 15600 15601 ... 15728 15729\n"
" 15730 15731 ... 15858 15859\n"
" ... ... ... ... ...\n"
" 30940 30941 ... 31068 31069\n"
" 31070 31071 ... 31198 31199\n"
"\n"
"(skipping 10'996 slices...)\n"
"\n"
"Tensor \"4D\", slice [99, 108, :, :]\n"
" 171568800 171568801 ... 171568928 171568929\n"
" 171568930 171568931 ... 171569058 171569059\n"
" ... ... ... ... ...\n"
" 171584140 171584141 ... 171584268 171584269\n"
" 171584270 171584271 ... 171584398 171584399\n"
"\n"
"Tensor \"4D\", slice [99, 109, :, :]\n"
" 171584400 171584401 ... 171584528 171584529\n"
" 171584530 171584531 ... 171584658 171584659\n"
" ... ... ... ... ...\n"
" 171599740 171599741 ... 171599868 171599869\n"
" 171599870 171599871 ... 171599998 171599999\n"));
}
// Prints a packed 10x10x10 I32 tensor with a fully customized print
// configuration (limits, row prefix, field separator and skip markers of
// differing widths) and checks the exact resulting text, including that the
// custom markers appear where rows, columns and slices are elided.
TEST(Debug, PrintTensorCustomConfig)
{
auto desc =
ckt::make_descriptor<ckb::DataType::I32>(ckt::Extent{10, 10, 10}, ckt::PackedRightLayout{});
auto a = ckt::alloc_tensor_buffer(desc);
// Pseudo-scrambled fill pattern so that neighbouring elements differ.
ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i * 101 % 77; });
std::stringstream ss;
ckt::print_tensor("CustomConfig",
desc,
a.get(),
{
// Reduce default limits to have smaller output here.
// That also tests that we can configure these.
.col_limit = 4,
.row_limit = 2,
.slice_limit = 6,
// Try with different sizes to make sure that the alignment
// is still correct after changing these.
.row_prefix = ">>>>",
.row_field_sep = "|||||",
.row_skip_val = "-------",
.matrix_row_skip_val = "&&&&&&&&",
},
ss);
// 10 slices in total; 6 are printed (3 leading, 3 trailing), so 4 are skipped.
EXPECT_THAT(ss.str(),
StringEqWithDiff( //
"Tensor \"CustomConfig\": shape = [10, 10, 10]\n"
"Tensor \"CustomConfig\", slice [0, :, :]\n"
">>>>||||| 0||||| 24|||||-------||||| 38||||| 62\n"
">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n"
">>>>||||| 4||||| 28|||||-------||||| 42||||| 66\n"
"\n"
"Tensor \"CustomConfig\", slice [1, :, :]\n"
">>>>||||| 13||||| 37|||||-------||||| 51||||| 75\n"
">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n"
">>>>||||| 17||||| 41|||||-------||||| 55||||| 2\n"
"\n"
"Tensor \"CustomConfig\", slice [2, :, :]\n"
">>>>||||| 26||||| 50|||||-------||||| 64||||| 11\n"
">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n"
">>>>||||| 30||||| 54|||||-------||||| 68||||| 15\n"
"\n"
"(skipping 4 slices...)\n"
"\n"
"Tensor \"CustomConfig\", slice [7, :, :]\n"
">>>>||||| 14||||| 38|||||-------||||| 52||||| 76\n"
">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n"
">>>>||||| 18||||| 42|||||-------||||| 56||||| 3\n"
"\n"
"Tensor \"CustomConfig\", slice [8, :, :]\n"
">>>>||||| 27||||| 51|||||-------||||| 65||||| 12\n"
">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n"
">>>>||||| 31||||| 55|||||-------||||| 69||||| 16\n"
"\n"
"Tensor \"CustomConfig\", slice [9, :, :]\n"
">>>>||||| 40||||| 64|||||-------||||| 1||||| 25\n"
">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n"
">>>>||||| 44||||| 68|||||-------||||| 5||||| 29\n"));
}
// With TensorPrintConfig::unlimited(), a 2D tensor that is larger than the
// default row/column limits must be printed in full, without any elision.
TEST(Debug, PrintTensorUnlimitedMatrix)
{
    // To limit the output of the test, split the "unlimited" test up into one for the
    // matrices and one for the slices.
    const ckt::Extent shape = ckt::Extent{12, 12};
    const ckt::TensorPrintConfig default_config;

    // The shape should be larger than the default, otherwise this test doesn't make
    // any sense. The shape is rank 2, so its only valid indices are 0 (rows) and
    // 1 (columns); indexing shape[2] would read past the end of the extent.
    ASSERT_THAT(shape[0], Gt(default_config.row_limit));
    ASSERT_THAT(shape[1], Gt(default_config.col_limit));

    auto desc = ckt::make_descriptor<ckb::DataType::I32>(shape, ckt::PackedRightLayout{});
    auto a    = ckt::alloc_tensor_buffer(desc);
    // XOR with 0xF scrambles the low four bits so that consecutive values differ
    // in width, which also exercises column alignment.
    ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i ^ 0xF; });

    std::stringstream ss;
    ckt::print_tensor("UnlimitedConfig", desc, a.get(), ckt::TensorPrintConfig::unlimited(), ss);

    EXPECT_THAT(ss.str(),
                StringEqWithDiff( //
                    "Tensor \"UnlimitedConfig\": shape = [12, 12]\n"
                    " 15 14 13 12 11 10 9 8 7 6 5 4\n"
                    " 3 2 1 0 31 30 29 28 27 26 25 24\n"
                    " 23 22 21 20 19 18 17 16 47 46 45 44\n"
                    " 43 42 41 40 39 38 37 36 35 34 33 32\n"
                    " 63 62 61 60 59 58 57 56 55 54 53 52\n"
                    " 51 50 49 48 79 78 77 76 75 74 73 72\n"
                    " 71 70 69 68 67 66 65 64 95 94 93 92\n"
                    " 91 90 89 88 87 86 85 84 83 82 81 80\n"
                    " 111 110 109 108 107 106 105 104 103 102 101 100\n"
                    " 99 98 97 96 127 126 125 124 123 122 121 120\n"
                    " 119 118 117 116 115 114 113 112 143 142 141 140\n"
                    " 139 138 137 136 135 134 133 132 131 130 129 128\n"));
}
// With TensorPrintConfig::unlimited(), a 3D tensor with more slices than the
// default slice limit must have every slice printed, with no
// "(skipping ...)" line in the output.
TEST(Debug, PrintTensorUnlimitedSlices)
{
// To limit the output of the test, split the "unlimited" test up into one for the
// matrices and one for the slices.
const ckt::Extent shape = ckt::Extent{13, 1, 1};
const ckt::TensorPrintConfig default_config;
// The shape should be larger than the default, otherwise this test doesn't make
// any sense.
ASSERT_THAT(shape[0], Gt(default_config.slice_limit));
auto desc = ckt::make_descriptor<ckb::DataType::I32>(shape, ckt::PackedRightLayout{});
auto a = ckt::alloc_tensor_buffer(desc);
// Each 1x1 slice holds a single value: three times its slice index.
ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i * 3; });
std::stringstream ss;
ckt::print_tensor("UnlimitedConfig", desc, a.get(), ckt::TensorPrintConfig::unlimited(), ss);
EXPECT_THAT(ss.str(),
StringEqWithDiff( //
"Tensor \"UnlimitedConfig\": shape = [13, 1, 1]\n"
"Tensor \"UnlimitedConfig\", slice [0, :, :]\n"
" 0\n"
"\n"
"Tensor \"UnlimitedConfig\", slice [1, :, :]\n"
" 3\n"
"\n"
"Tensor \"UnlimitedConfig\", slice [2, :, :]\n"
" 6\n"
"\n"
"Tensor \"UnlimitedConfig\", slice [3, :, :]\n"
" 9\n"
"\n"
"Tensor \"UnlimitedConfig\", slice [4, :, :]\n"
" 12\n"
"\n"
"Tensor \"UnlimitedConfig\", slice [5, :, :]\n"
" 15\n"
"\n"
"Tensor \"UnlimitedConfig\", slice [6, :, :]\n"
" 18\n"
"\n"
"Tensor \"UnlimitedConfig\", slice [7, :, :]\n"
" 21\n"
"\n"
"Tensor \"UnlimitedConfig\", slice [8, :, :]\n"
" 24\n"
"\n"
"Tensor \"UnlimitedConfig\", slice [9, :, :]\n"
" 27\n"
"\n"
"Tensor \"UnlimitedConfig\", slice [10, :, :]\n"
" 30\n"
"\n"
"Tensor \"UnlimitedConfig\", slice [11, :, :]\n"
" 33\n"
"\n"
"Tensor \"UnlimitedConfig\", slice [12, :, :]\n"
" 36\n"));
}
// FP32 tensors are printed in fixed notation with three decimals by default;
// the values span several orders of magnitude to exercise column alignment.
TEST(Debug, PrintTensorFP32)
{
    auto descriptor =
        ckt::make_descriptor<ckb::DataType::FP32>(ckt::Extent{5, 5}, ckt::PackedRightLayout{});
    auto buffer = ckt::alloc_tensor_buffer(descriptor);

    // Element i holds 1.9999^i, giving values from 1 up to roughly 1.7e7.
    ckt::fill_tensor_buffer(
        descriptor, buffer.get(), [](size_t i) { return std::pow(1.9999, i); });

    std::stringstream out;
    ckt::print_tensor("FP32", descriptor, buffer.get(), {}, out);
    const std::string printed = out.str();

    EXPECT_THAT(printed,
                StringEqWithDiff( //
                    "Tensor \"FP32\": shape = [5, 5]\n"
                    " 1.000 2.000 4.000 7.999 15.997\n"
                    " 31.992 63.981 127.955 255.898 511.770\n"
                    " 1023.488 2046.874 4093.543 8186.677 16372.535\n"
                    " 32743.432 65483.590 130960.633 261908.172 523790.156\n"
                    " 1047527.938 2094951.125 4189692.750 8378966.500 16757095.000\n"));
}
// BF16 tensors are printed like other floating-point types; the expected
// values reflect the rounding introduced by converting to ck::bhalf_t.
TEST(Debug, PrintTensorBF16)
{
    auto descriptor =
        ckt::make_descriptor<ckb::DataType::BF16>(ckt::Extent{5, 5}, ckt::PackedRightLayout{});
    auto buffer = ckt::alloc_tensor_buffer(descriptor);

    // Element i holds 1.2345678 * i, converted to bf16.
    ckt::fill_tensor_buffer(descriptor, buffer.get(), [](size_t i) {
        return ck::type_convert<ck::bhalf_t>(1.2345678f * i);
    });

    std::stringstream out;
    ckt::print_tensor("BF16", descriptor, buffer.get(), {}, out);
    const std::string printed = out.str();

    EXPECT_THAT(printed,
                StringEqWithDiff( //
                    "Tensor \"BF16\": shape = [5, 5]\n"
                    " 0.000 1.234 2.469 3.703 4.938\n"
                    " 6.188 7.406 8.625 9.875 11.125\n"
                    " 12.375 13.562 14.812 16.000 17.250\n"
                    " 18.500 19.750 21.000 22.250 23.500\n"
                    " 24.750 25.875 27.125 28.375 29.625\n"));
}
// FP8 tensors are printed like other floating-point types; the expected
// values reflect the coarse rounding of converting to ck::f8_t (note the
// repeated values such as 1.250 and 1.750).
TEST(Debug, PrintTensorFP8)
{
    auto descriptor =
        ckt::make_descriptor<ckb::DataType::FP8>(ckt::Extent{5, 5}, ckt::PackedRightLayout{});
    auto buffer = ckt::alloc_tensor_buffer(descriptor);

    // Element i holds 0.1 * i, converted to fp8.
    ckt::fill_tensor_buffer(descriptor, buffer.get(), [](size_t i) {
        return ck::type_convert<ck::f8_t>(i * 0.1f);
    });

    std::stringstream out;
    ckt::print_tensor("FP8", descriptor, buffer.get(), {}, out);
    const std::string printed = out.str();

    EXPECT_THAT(printed,
                StringEqWithDiff( //
                    "Tensor \"FP8\": shape = [5, 5]\n"
                    " 0.000 0.102 0.203 0.312 0.406\n"
                    " 0.500 0.625 0.688 0.812 0.875\n"
                    " 1.000 1.125 1.250 1.250 1.375\n"
                    " 1.500 1.625 1.750 1.750 1.875\n"
                    " 2.000 2.000 2.250 2.250 2.500\n"));
}
// Checks how non-finite FP32 values (NaN, +inf, -inf) are rendered when mixed
// with ordinary values, and that columns stay aligned around them.
TEST(Debug, PrintTensorSpecialFloats)
{
auto desc =
ckt::make_descriptor<ckb::DataType::FP32>(ckt::Extent{5, 5}, ckt::PackedRightLayout{});
auto a = ckt::alloc_tensor_buffer(desc);
// The branches are tested in order, so an index that satisfies several of the
// conditions takes the value of the first matching one.
ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) {
if(i % 8 == 1)
return 0.f / 0.f;
else if(i % 7 == 1)
return std::sqrt(-1.f);
else if(i % 6 == 1)
return 1.f / 0.f;
else if(i % 5 == 1)
return -1.f / 0.f;
else
return static_cast<float>(i);
});
std::stringstream ss;
ckt::print_tensor("specials", desc, a.get(), {}, ss);
// The expected text contains both "nan" and "-nan": the two NaN-producing
// expressions above are expected to yield NaNs that print with different signs.
EXPECT_THAT(ss.str(),
StringEqWithDiff( //
"Tensor \"specials\": shape = [5, 5]\n"
" 0.000 nan 2.000 3.000 4.000\n"
" 5.000 -inf inf -nan nan\n"
" 10.000 -inf 12.000 inf 14.000\n"
" -nan -inf nan 18.000 inf\n"
" 20.000 -inf -nan 23.000 24.000\n"));
}
// The float_precision option controls the number of printed decimals; with
// ten decimals the FP32 rounding error of the stored values becomes visible.
TEST(Debug, PrintTensorFloatPrecision)
{
    auto descriptor =
        ckt::make_descriptor<ckb::DataType::FP32>(ckt::Extent{5}, ckt::PackedRightLayout{});
    auto buffer = ckt::alloc_tensor_buffer(descriptor);

    // Element i holds 0.9^i.
    ckt::fill_tensor_buffer(descriptor, buffer.get(), [](size_t i) { return std::pow(0.9, i); });

    std::stringstream out;
    ckt::print_tensor("FloatPrecision",
                      descriptor,
                      buffer.get(),
                      {
                          .float_precision = 10,
                      },
                      out);
    const std::string printed = out.str();

    EXPECT_THAT(printed,
                StringEqWithDiff( //
                    "Tensor \"FloatPrecision\": shape = [5]\n"
                    " 1.0000000000 0.8999999762 0.8100000024 0.7289999723 0.6560999751\n"));
}

View File

@@ -88,3 +88,11 @@ TEST(DeviceBuffer, AllocTensorBuffer)
EXPECT_THAT(hipMemset(buffer.get(), 0xFF, descriptor.get_element_space_size_in_bytes()),
HipSuccess());
}
// align_fwd(value, alignment) is expected to round `value` up to the next
// multiple of `alignment`, leaving already-aligned values unchanged.
TEST(DeviceBuffer, AlignForward)
{
// Already aligned: unchanged.
EXPECT_THAT(ckt::align_fwd(24, 8), Eq(24));
// One past a multiple: rounds up to the next multiple.
EXPECT_THAT(ckt::align_fwd(25, 8), Eq(32));
// Page-sized (0x1000) alignment of a large value.
EXPECT_THAT(ckt::align_fwd(0xd7c563, 0x1000), Eq(0xd7d000));
// Alignment that is not a power of two (19573 == 23 * 851).
EXPECT_THAT(ckt::align_fwd(19561, 23), Eq(19573));
}

View File

@@ -6,11 +6,13 @@
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <array>
#include <sstream>
#include <vector>
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
using ck_tile::test::StringEqWithDiff;
using ::testing::ElementsAreArray;
using ::testing::Eq;
using ::testing::Throws;
@@ -76,7 +78,7 @@ TEST(TensorDescriptor, MakeDescriptor)
// Note: automatic inference of RANK.
const auto desc =
ckt::make_descriptor<ckb::DataType::INT32>(lengths, ckt::PackedRightLayout{});
ckt::make_descriptor<ckb::DataType::I32>(lengths, ckt::PackedRightLayout{});
EXPECT_THAT(desc.get_lengths(), ElementsAreArray(lengths));
EXPECT_THAT(desc.get_strides(),
@@ -173,7 +175,7 @@ TEST(TensorDescriptor, ExtentFromVector)
TEST(TensorDescriptor, IsPacked)
{
constexpr auto dt = ckb::DataType::INT32; // Irrelevant for this test
constexpr auto dt = ckb::DataType::I32; // Irrelevant for this test
EXPECT_TRUE(
ckt::make_descriptor<dt>(ckt::Extent{101, 43, 25, 662, 654}, ckt::PackedLeftLayout{})
.is_packed());
@@ -189,3 +191,20 @@ TEST(TensorDescriptor, IsPacked)
EXPECT_FALSE(
ckt::make_descriptor<dt>(ckt::Extent{30, 20, 10}, ckt::Extent{1, 1, 1}).is_packed());
}
// Streaming an Extent produces a bracketed, comma-separated list of its
// elements; an empty Extent produces "[]".
TEST(TensorDescriptor, PrintExtent)
{
    // Non-empty extent.
    {
        const ckt::Extent dims{6233, 55, 1235, 52, 203};

        std::stringstream out;
        out << dims;
        const std::string printed = out.str();

        EXPECT_THAT(printed, StringEqWithDiff("[6233, 55, 1235, 52, 203]"));
    }

    // Rank-0 extent.
    {
        const ckt::Extent no_dims{};

        std::stringstream out;
        out << no_dims;
        const std::string printed = out.str();

        EXPECT_THAT(printed, StringEqWithDiff("[]"));
    }
}

View File

@@ -16,6 +16,28 @@ namespace ckt = ck_tile::builder::test;
using ::testing::Each;
using ::testing::Eq;
// NdIter maps a linear index onto a multi-dimensional index for a given shape,
// with the last dimension varying fastest, and reports the total element count.
TEST(TensorForeach, NdIter)
{
    // 4D shape whose element count does not fit into 32 bits.
    {
        ckt::NdIter nd(ckt::Extent{523, 345, 123, 601});

        EXPECT_THAT(nd.numel(), Eq(13'338'296'505ULL));

        // Stepping by one advances the innermost dimension; stepping by the
        // size of a dimension advances the next outer one.
        EXPECT_THAT(nd(0), Eq(ckt::Extent{0, 0, 0, 0}));
        EXPECT_THAT(nd(1), Eq(ckt::Extent{0, 0, 0, 1}));
        EXPECT_THAT(nd(601), Eq(ckt::Extent{0, 0, 1, 0}));
        EXPECT_THAT(nd(601 * 123), Eq(ckt::Extent{0, 1, 0, 0}));
        EXPECT_THAT(nd(601 * 123 * 10), Eq(ckt::Extent{0, 10, 0, 0}));

        // Arbitrary index built from its coordinates (34, 63, 70, 5).
        EXPECT_THAT(nd(((34 * 345 + 63) * 123 + 70) * 601 + 5), Eq(ckt::Extent{34, 63, 70, 5}));
    }

    // Rank-0 shape: exactly one element, addressed by the empty index.
    {
        ckt::NdIter nd(ckt::Extent{});

        EXPECT_THAT(nd.numel(), Eq(1));
        EXPECT_THAT(nd(0), Eq(ckt::Extent{}));
    }
}
TEST(TensorForeach, CalculateOffset)
{
EXPECT_THAT(ckt::calculate_offset(ckt::Extent{1, 2, 3}, ckt::Extent{100, 10, 1}), Eq(123));
@@ -87,8 +109,8 @@ TEST(TensorForeach, VisitsEveryIndex)
TEST(TensorForeach, FillTensorBuffer)
{
auto desc = ckt::make_descriptor<ckb::DataType::INT32>(ckt::Extent{31, 54, 13},
ckt::PackedRightLayout{});
auto desc =
ckt::make_descriptor<ckb::DataType::I32>(ckt::Extent{31, 54, 13}, ckt::PackedRightLayout{});
auto buffer = ckt::alloc_tensor_buffer(desc);
@@ -109,7 +131,7 @@ TEST(TensorForeach, FillTensor)
// FillTensor with non-packed indices should not write out-of-bounds.
const ckt::Extent shape = {4, 23, 35};
const ckt::Extent pad = {12, 53, 100};
auto desc = ckt::make_descriptor<ckb::DataType::INT32>(shape, ckt::PackedRightLayout{}(pad));
auto desc = ckt::make_descriptor<ckb::DataType::I32>(shape, ckt::PackedRightLayout{}(pad));
const auto strides = desc.get_strides();
auto size = desc.get_element_space_size();
@@ -169,7 +191,7 @@ TEST(TensorForeach, ClearTensorZeros)
const ckt::Extent pad = {6, 6, 6, 6, 6, 6, 6, 6};
const auto desc =
ckt::make_descriptor<ckb::DataType::INT32>(shape, ckt::PackedRightLayout{}(pad));
ckt::make_descriptor<ckb::DataType::I32>(shape, ckt::PackedRightLayout{}(pad));
auto buffer = ckt::alloc_tensor_buffer(desc);
ckt::clear_tensor_buffer(desc, buffer.get());

View File

@@ -173,8 +173,8 @@ TEST(ValidationReportTests, MultipleSomeIncorrect)
}
{
auto desc = ckt::make_descriptor<ckb::DataType::INT32, 3>({'G', 'P', 'U'},
ckt::PackedRightLayout{});
auto desc =
ckt::make_descriptor<ckb::DataType::I32, 3>({'G', 'P', 'U'}, ckt::PackedRightLayout{});
auto a = ckt::alloc_tensor_buffer(desc);
auto b = ckt::alloc_tensor_buffer(desc);
@@ -204,6 +204,7 @@ struct DummySignature
constexpr DummySignature DUMMY_SIGNATURE = {};
namespace ck_tile::builder::test {
template <>
struct Args<DUMMY_SIGNATURE>
{
@@ -225,6 +226,7 @@ struct Outputs<DUMMY_SIGNATURE>
void* b;
};
// Explicitly implement validate for this type to test that that works.
template <>
ValidationReport validate<DUMMY_SIGNATURE>(const Args<DUMMY_SIGNATURE>& args,
Outputs<DUMMY_SIGNATURE> actual,

View File

@@ -15,31 +15,42 @@ using namespace test;
constexpr DlThreadConfig DlThreadConfig_16x2x4x4x1{
.k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
constexpr DlThreadConfig DlThreadConfig_16x1x4x4x1{
.k0_per_block = 16, .k1 = 1, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
constexpr DlThreadCluster DlThreadCluster_8x2{.m1_xs = {8, 2}, .n1_xs = {8, 2}};
constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8, 1, 1, 2},
.thread_cluster_lengths = {2, 1, 128, 1},
.thread_cluster_arrange_order = {1, 2, 0, 3},
.src_access_order = {1, 2, 0, 3},
.src_vector_tensor_lengths = {4, 1, 1, 2},
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
// Four-entry DL block-transfer descriptor used by the tests; the name suffix mirrors
// thread_slice_lengths {8, 1, 1, 2}. Every field below is a 4-element list, matching the
// DlBlockTransfer<4> template argument.
constexpr DlBlockTransfer<4> DlBlockTransfer_8x1x1x2{
.thread_slice_lengths = {8, 1, 1, 2},
.thread_cluster_lengths = {2, 1, 128, 1},
.thread_cluster_arrange_order = {1, 2, 0, 3},
.src_access_order = {1, 2, 0, 3},
.src_vector_tensor_lengths = {4, 1, 1, 2},
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
constexpr DlTransferABC DlFwdTransfer{.a =
{
.block_transfer = DlBlockTransferAB,
},
.b =
{
.block_transfer = DlBlockTransferAB,
},
.c = {
.epilogue = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
.src_dst_vector_dim = 5,
.dst_scalar_per_vector = 4},
}};
// DlTransfer<4> test configuration: .a and .b both reuse DlBlockTransfer_8x1x1x2, while .c
// keeps the identity access order {0..5}, sets src_dst_vector_dim = 5 and writes
// dst_scalar_per_vector = 4.
constexpr DlTransfer<4> DlTransfer4D{.a = DlBlockTransfer_8x1x1x2,
.b = DlBlockTransfer_8x1x1x2,
.c = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
.src_dst_vector_dim = 5,
.dst_scalar_per_vector = 4}};
constexpr TransferABC FwdTransfer_4x64x1{
// Five-entry DL block-transfer descriptor used by the tests; the name suffix mirrors
// thread_slice_lengths {1, 8, 1, 1, 1}. Both vector-tensor length lists are all ones.
constexpr DlBlockTransfer<5> DlBlockTransfer_1x8x1x1x1{
.thread_slice_lengths = {1, 8, 1, 1, 1},
.thread_cluster_lengths = {1, 2, 1, 128, 1},
.thread_cluster_arrange_order = {0, 2, 3, 1, 4},
.src_access_order = {0, 2, 3, 1, 4},
.src_vector_tensor_lengths = {1, 1, 1, 1, 1},
.src_vector_tensor_contiguous_dim_order = {0, 2, 3, 1, 4},
.dst_vector_tensor_lengths = {1, 1, 1, 1, 1}};
// DlTransfer<5> test configuration: .a and .b both reuse DlBlockTransfer_1x8x1x1x1, while .c
// keeps the identity access order {0..5}, sets src_dst_vector_dim = 5 and writes
// dst_scalar_per_vector = 1.
constexpr DlTransfer<5> DlTransfer5D{.a = DlBlockTransfer_1x8x1x1x1,
.b = DlBlockTransfer_1x8x1x1x1,
.c = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
.src_dst_vector_dim = 5,
.dst_scalar_per_vector = 1}};
constexpr Transfer<> Transfer_4x64x1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
@@ -72,7 +83,73 @@ constexpr TransferABC FwdTransfer_4x64x1{
},
};
constexpr TransferABC FwdTransfer_4x64x1_fp8{
// Transfer<4> test configuration. The A and B operands are configured identically:
// block transfer k0 = 4, m_n = 64, k1 = 1 with an explicit k_batch_size of 1, LDS padding
// enabled, and four-entry access orders. The C side uses a 1 x 32 x 1 x 8 thread cluster and
// an epilogue writing 8 scalars per vector.
constexpr Transfer<4> BwdTransfer_4x64x1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 4,
.is_direct_load = false,
.lds_padding = true},
.block_transfer_access_order = {0, 3, 1, 2},
.src_access_order = {0, 2, 1, 3},
},
.b =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
.lds_transfer = {.src_vector_dim = 2,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 4,
.is_direct_load = false,
.lds_padding = true},
.block_transfer_access_order = {0, 3, 1, 2},
.src_access_order = {0, 2, 1, 3},
},
.c =
{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
},
};
// Transfer<> (default template argument) test configuration with asymmetric operands:
// A uses block transfer k0 = 4, m_n = 8, k1 = 1 and B uses k0 = 4, m_n = 16, k1 = 1.
// Both operands disable LDS padding and use three-entry access orders. The C side uses a
// 1 x 8 x 1 x 8 thread cluster and an epilogue writing 2 scalars per vector.
constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1},
.lds_transfer = {.src_vector_dim = 1,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 2,
.is_direct_load = false,
.lds_padding = false},
.block_transfer_access_order = {2, 0, 1},
.src_access_order = {1, 0, 2},
},
.b =
{
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
.lds_transfer = {.src_vector_dim = 1,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 2,
.is_direct_load = false,
.lds_padding = false},
.block_transfer_access_order = {2, 0, 1},
.src_access_order = {1, 0, 2},
},
.c =
{
.thread_cluster_dims =
{.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8},
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
.n_per_wave_per_shuffle = 1,
.scalar_per_vector = 2},
},
};
constexpr Transfer<> Transfer_4x64x1_fp8{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
@@ -105,7 +182,7 @@ constexpr TransferABC FwdTransfer_4x64x1_fp8{
},
};
constexpr TransferABC FwdTransfer_4x16x1{
constexpr Transfer<> Transfer_4x16x1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
@@ -139,7 +216,7 @@ constexpr TransferABC FwdTransfer_4x16x1{
},
};
constexpr TransferABC FwdTransfer_4x32x1{
constexpr Transfer<> Transfer_4x32x1{
.a =
{
.block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1},
@@ -172,59 +249,80 @@ constexpr TransferABC FwdTransfer_4x32x1{
},
};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4};
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{
.k1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2};
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_1x1_per_wave{
.k1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 1, .n_xdl_per_wave = 1}};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x2_per_wave{
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2};
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
.ak1 = 8,
.bk1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1};
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
.ak1 = 8,
.bk1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}};
constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8,
.m_per_wmma = 32,
.n_per_wmma = 32,
.m_wmma_per_wave = 2,
.n_wmma_per_wave = 1,
.pipeline_version = PipelineVersion::V1};
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x2_per_wave{
.ak1 = 8,
.bk1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}};
constexpr ThreadBlock FwdThreadBlock_256_256x256x32{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
.ak1 = 8,
.bk1 = 8,
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}};
constexpr ThreadBlock FwdThreadBlock_256_256x128x32{.block_size = 256,
.tile_size = {.m = 256, .n = 128, .k = 32}};
constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{
.k1 = 8, .m_per_wmma = 32, .n_per_wmma = 32, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1};
constexpr ThreadBlock FwdThreadBlock_256_128x128x32{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{
.k1 = 8, .m_per_wmma = 16, .n_per_wmma = 16, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1};
constexpr ThreadBlock FwdThreadBlock_256_128x128x16{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 16}};
constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr ThreadBlock FwdThreadBlock_64_64x32x32{.block_size = 64,
.tile_size = {.m = 64, .n = 32, .k = 32}};
constexpr ThreadBlock ThreadBlock_256_256x128x32{.block_size = 256,
.tile_size = {.m = 256, .n = 128, .k = 32}};
constexpr ThreadBlock FwdThreadBlock_128_128x128x32{.block_size = 128,
.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr ThreadBlock ThreadBlock_256_128x128x32{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr ThreadBlock FwdThreadBlock_128_64x64x64{.block_size = 128,
.tile_size = {.m = 64, .n = 64, .k = 64}};
constexpr ThreadBlock ThreadBlock_256_128x128x16{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 16}};
constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1,
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr ThreadBlock ThreadBlock_256_128x128x8{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 8}};
constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2,
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64,
.tile_size = {.m = 64, .n = 32, .k = 32}};
constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3,
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr ThreadBlock ThreadBlock_64_32x32x32{.block_size = 64,
.tile_size = {.m = 32, .n = 32, .k = 32}};
constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4,
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128,
.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5,
.scheduler = PipelineScheduler::INTRAWAVE};
constexpr ThreadBlock ThreadBlock_128_64x64x64{.block_size = 128,
.tile_size = {.m = 64, .n = 64, .k = 64}};
constexpr BlockGemmPipeline BlockGemmDesc_v1_intrawave = {
.pipeline_version = PipelineVersion::V1, .scheduler = PipelineScheduler::INTRAWAVE};
constexpr BlockGemmPipeline BlockGemmDesc_v2_intrawave = {
.pipeline_version = PipelineVersion::V2, .scheduler = PipelineScheduler::INTRAWAVE};
constexpr BlockGemmPipeline BlockGemmDesc_v3_intrawave = {
.pipeline_version = PipelineVersion::V3, .scheduler = PipelineScheduler::INTRAWAVE};
constexpr BlockGemmPipeline BlockGemmDesc_v4_intrawave = {
.pipeline_version = PipelineVersion::V4, .scheduler = PipelineScheduler::INTRAWAVE};
constexpr BlockGemmPipeline BlockGemmDesc_v5_intrawave = {
.pipeline_version = PipelineVersion::V5, .scheduler = PipelineScheduler::INTRAWAVE};
} // namespace ck_tile::builder::test_utils

View File

@@ -12,35 +12,35 @@ namespace ck_tile::builder::test_utils {
using namespace ck_tile::builder;
using namespace test;
constexpr TileTransfer FwdTileTransfer_1x1x1{
constexpr TileTransfer TileTransfer_1x1x1{
.a_scalar_per_vector = 1,
.b_scalar_per_vector = 1,
.c_scalar_per_vector = 1,
};
constexpr TileTransfer FwdTileTransfer_4x4x4{
constexpr TileTransfer TileTransfer_4x4x4{
.a_scalar_per_vector = 4,
.b_scalar_per_vector = 4,
.c_scalar_per_vector = 4,
};
constexpr TileTransfer FwdTileTransfer_8x8x8{
constexpr TileTransfer TileTransfer_8x8x8{
.a_scalar_per_vector = 8,
.b_scalar_per_vector = 8,
.c_scalar_per_vector = 8,
};
constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr TileThreadBlock TileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}};
constexpr TileThreadBlock TileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr TileThreadBlock TileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}};
constexpr TileThreadBlock TileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}};
constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}};
constexpr TileThreadBlock TileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}};
constexpr TileThreadBlock TileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}};
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = {
.warps = {.m = 2, .n = 2, .k = 1},

View File

@@ -54,7 +54,7 @@ inline std::string to_string<PipelineScheduler>(PipelineScheduler t)
}
template <>
inline std::string to_string<ConvFwdSpecialization>(ConvFwdSpecialization t)
inline std::string to_string<ConvSpecialization>(ConvSpecialization t)
{
std::ostringstream oss;
oss << t;
@@ -86,11 +86,20 @@ inline std::string to_string<ThreadBlock>(ThreadBlock t)
}
template <>
inline std::string to_string<GridwiseXdlGemm>(GridwiseXdlGemm t)
inline std::string to_string<GridwiseBwdXdlGemm>(GridwiseBwdXdlGemm t)
{
std::ostringstream oss;
oss << t.ak1 << "," << t.bk1 << "," << t.m_per_xdl << "," << t.n_per_xdl << ","
<< t.m_xdl_per_wave << "," << t.n_xdl_per_wave;
oss << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << ","
<< t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave;
return oss.str();
}
// Formats a GridwiseFwdXdlGemm as the comma-separated list
// "ak1,bk1,m_per_xdl,n_per_xdl,m_xdl_per_wave,n_xdl_per_wave".
template <>
inline std::string to_string<GridwiseFwdXdlGemm>(GridwiseFwdXdlGemm t)
{
    const auto& xdl = t.xdl_params;

    std::ostringstream text;
    text << t.ak1 << ",";
    text << t.bk1 << ",";
    text << xdl.m_per_xdl << ",";
    text << xdl.n_per_xdl << ",";
    text << xdl.m_xdl_per_wave << ",";
    text << xdl.n_xdl_per_wave;
    return text.str();
}
@@ -104,17 +113,29 @@ inline std::string to_string<GridwiseWmmaGemm>(GridwiseWmmaGemm t)
}
template <>
inline std::string to_string<BlockGemm>(BlockGemm t)
inline std::string to_string<BlockGemmPipeline>(BlockGemmPipeline t)
{
std::ostringstream oss;
oss << to_string(t.scheduler) << "," << to_string(t.pipeline_version);
return oss.str();
}
template <>
inline std::string to_string<BlockTransfer>(BlockTransfer t)
template <size_t ThreadClusterRank>
inline std::string to_string(BlockTransfer<ThreadClusterRank> t)
{
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
if constexpr(ThreadClusterRank == 4)
{
return array_to_seq(std::array<size_t, 4>{t.k_batch_size, t.k0, t.m_n, t.k1});
}
else if constexpr(ThreadClusterRank == 3)
{
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
}
else
{
static_assert(ThreadClusterRank == 3 || ThreadClusterRank == 4,
"Unsupported ThreadClusterRank");
}
}
template <>
@@ -134,14 +155,14 @@ inline std::string to_string<LdsTransfer>(LdsTransfer t)
return oss.str();
}
template <>
inline std::string to_string<AccessOrder>(AccessOrder t)
template <size_t N>
inline std::string to_string(AccessOrder<N> t)
{
return array_to_seq(t.order);
}
template <>
inline std::string to_string<TransferAB>(TransferAB t)
template <size_t N = 3>
inline std::string to_string(InputTransfer<N> t)
{
std::ostringstream oss;
oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << ","
@@ -152,7 +173,7 @@ inline std::string to_string<TransferAB>(TransferAB t)
}
template <>
inline std::string to_string<TransferC>(TransferC t)
inline std::string to_string<OutputTransfer>(OutputTransfer t)
{
std::ostringstream oss;
oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << ","
@@ -160,8 +181,8 @@ inline std::string to_string<TransferC>(TransferC t)
return oss.str();
}
template <>
inline std::string to_string<TransferABC>(TransferABC t)
template <size_t N = 3>
inline std::string to_string(Transfer<N> t)
{
std::ostringstream oss;
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
@@ -185,7 +206,19 @@ inline std::string to_string<DlThreadCluster>(DlThreadCluster t)
}
template <>
inline std::string to_string<DlBlockTransfer>(DlBlockTransfer t)
inline std::string to_string<DlBlockTransfer<4>>(DlBlockTransfer<4> t)
{
std::ostringstream oss;
oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths)
<< "," << array_to_seq(t.thread_cluster_arrange_order) << ","
<< array_to_seq(t.src_access_order) << "," << array_to_seq(t.src_vector_tensor_lengths)
<< "," << array_to_seq(t.src_vector_tensor_contiguous_dim_order) << ","
<< array_to_seq(t.dst_vector_tensor_lengths);
return oss.str();
}
template <>
inline std::string to_string<DlBlockTransfer<5>>(DlBlockTransfer<5> t)
{
std::ostringstream oss;
oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths)
@@ -206,19 +239,24 @@ inline std::string to_string<DlEpilogue>(DlEpilogue t)
}
template <>
inline std::string to_string<DlBlockTransferAB>(DlBlockTransferAB t)
inline std::string to_string<TransposeParams_>(TransposeParams_ t)
{
return to_string(t.block_transfer);
std::ostringstream oss;
oss << t.max_transpose_transfer_src_scalar_per_vector << ","
<< t.max_transpose_transfer_dst_scalar_per_vector;
return oss.str();
}
template <>
inline std::string to_string<DlBlockTransferC>(DlBlockTransferC t)
inline std::string to_string<DlTransfer<4>>(DlTransfer<4> t)
{
return to_string(t.epilogue);
std::ostringstream oss;
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
return oss.str();
}
template <>
inline std::string to_string<DlTransferABC>(DlTransferABC t)
inline std::string to_string<DlTransfer<5>>(DlTransfer<5> t)
{
std::ostringstream oss;
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
@@ -234,7 +272,13 @@ inline std::string to_string<ThreadBlock_>(ThreadBlock_ t)
}
template <>
inline std::string to_string<XdlGemm_>(XdlGemm_ t)
inline std::string to_string<FwdXdlGemm_>(FwdXdlGemm_ t)
{
return to_string(t.gridwise_gemm);
}
// Stringifies a BwdXdlGemm_ by delegating to to_string on its gridwise_gemm member.
template <>
inline std::string to_string<BwdXdlGemm_>(BwdXdlGemm_ t)
{
    const auto& gridwise_gemm = t.gridwise_gemm;
    return to_string(gridwise_gemm);
}
@@ -245,33 +289,40 @@ inline std::string to_string<WmmaGemm_>(WmmaGemm_ t)
return to_string(t.gridwise_gemm);
}
template <>
inline std::string to_string<Transfer_>(Transfer_ t)
template <size_t ThreadClusterRank = 3>
inline std::string to_string(Transfer_<ThreadClusterRank> t)
{
return to_string(t.transfer);
}
template <>
inline std::string to_string<ConvSpecialization_>(ConvSpecialization_ t)
inline std::string to_string<ConvSpecializationFwd_>(ConvSpecializationFwd_ t)
{
std::ostringstream oss;
oss << to_string(t.fwd_specialization) << "," << to_string(t.gemm_specialization);
return oss.str();
}
// Stringifies a ConvSpecializationBwdWeight_; the result is exactly the string form of its
// bwd_weight_specialization member.
template <>
inline std::string to_string<ConvSpecializationBwdWeight_>(ConvSpecializationBwdWeight_ t)
{
    const std::string specialization = to_string(t.bwd_weight_specialization);
    return specialization;
}
template <>
inline std::string to_string<Prefetch_>(Prefetch_ t)
{
std::ostringstream oss;
oss << t.num_gemm_k_prefetch_stages << "," << t.num_groups_to_merge << ","
<< to_string(t.loop_scheduler);
oss << t.num_gemm_k_prefetch_stages << "," << to_string(t.loop_scheduler);
return oss.str();
}
template <>
inline std::string to_string<BlockGemm_>(BlockGemm_ t)
{
return to_string(t.block_gemm);
return to_string(t.block_gemm_pipeline);
}
template <>
@@ -287,7 +338,13 @@ inline std::string to_string<DlThreadCluster_>(DlThreadCluster_ t)
}
template <>
inline std::string to_string<DlTransfer_>(DlTransfer_ t)
inline std::string to_string<DlTransfer_<4>>(DlTransfer_<4> t)
{
return to_string(t.transfer);
}
// Stringifies a DlTransfer_<5> by delegating to to_string on its transfer member.
template <>
inline std::string to_string<DlTransfer_<5>>(DlTransfer_<5> t)
{
    const auto& transfer = t.transfer;
    return to_string(transfer);
}
@@ -299,8 +356,8 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<XdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_>(t));
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
@@ -309,8 +366,8 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 t)
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<XdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_>(t));
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
@@ -320,7 +377,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CS
{
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
<< "," << to_string(static_cast<Transfer_>(t));
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
@@ -332,7 +389,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_
oss << to_string(static_cast<ThreadBlock_>(t)) << ","
<< to_string(static_cast<DlThreadConfig_>(t)) << ","
<< to_string(static_cast<DlThreadCluster_>(t)) << ","
<< to_string(static_cast<DlTransfer_>(t));
<< to_string(static_cast<DlTransfer_<4>>(t));
return oss.str();
}
@@ -340,7 +397,102 @@ template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor>(
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor t)
{
return to_string(t.base_algorithm);
std::ostringstream oss;
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
<< "," << to_string(static_cast<Transfer_<>>(t));
return oss.str();
}
// Formats the algorithm as "<ThreadBlock_>,<BwdXdlGemm_>,<Transfer_<4>>", where each part is
// the string form of t after a static_cast to that type.
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle>(
    ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle t)
{
    const std::string thread_block = to_string(static_cast<ThreadBlock_>(t));
    const std::string gemm         = to_string(static_cast<BwdXdlGemm_>(t));
    const std::string transfer     = to_string(static_cast<Transfer_<4>>(t));
    return thread_block + "," + gemm + "," + transfer;
}
// Formats the algorithm as "<ThreadBlock_>,<BwdXdlGemm_>,<Transfer_<>>", where each part is
// the string form of t after a static_cast to that type.
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3>(
    ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 t)
{
    const std::string thread_block = to_string(static_cast<ThreadBlock_>(t));
    const std::string gemm         = to_string(static_cast<BwdXdlGemm_>(t));
    const std::string transfer     = to_string(static_cast<Transfer_<>>(t));
    return thread_block + "," + gemm + "," + transfer;
}
// Formats the algorithm as "<ThreadBlock_>,<WmmaGemm_>,<Transfer_<>>", where each part is
// the string form of t after a static_cast to that type.
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle>(
    ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle t)
{
    const std::string thread_block = to_string(static_cast<ThreadBlock_>(t));
    const std::string gemm         = to_string(static_cast<WmmaGemm_>(t));
    const std::string transfer     = to_string(static_cast<Transfer_<>>(t));
    return thread_block + "," + gemm + "," + transfer;
}
// Formats the algorithm as "<ThreadBlock_>,<WmmaGemm_>,<Transfer_<>>", where each part is
// the string form of t after a static_cast to that type.
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3>(
    ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 t)
{
    const std::string thread_block = to_string(static_cast<ThreadBlock_>(t));
    const std::string gemm         = to_string(static_cast<WmmaGemm_>(t));
    const std::string transfer     = to_string(static_cast<Transfer_<>>(t));
    return thread_block + "," + gemm + "," + transfer;
}
// Formats the algorithm as "<ThreadBlock_>,<WmmaGemm_>,<Transfer_<>>", where each part is
// the string form of t after a static_cast to that type.
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3>(
    ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 t)
{
    const std::string thread_block = to_string(static_cast<ThreadBlock_>(t));
    const std::string gemm         = to_string(static_cast<WmmaGemm_>(t));
    const std::string transfer     = to_string(static_cast<Transfer_<>>(t));
    return thread_block + "," + gemm + "," + transfer;
}
// Formats the algorithm as "<ThreadBlock_>,<WmmaGemm_>,<Transfer_<>>", where each part is
// the string form of t after a static_cast to that type.
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3>(
    ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 t)
{
    const std::string thread_block = to_string(static_cast<ThreadBlock_>(t));
    const std::string gemm         = to_string(static_cast<WmmaGemm_>(t));
    const std::string transfer     = to_string(static_cast<Transfer_<>>(t));
    return thread_block + "," + gemm + "," + transfer;
}
// Formats the algorithm as "<ThreadBlock_>,<BwdXdlGemm_>,<Transfer_<>>", where each part is
// the string form of t after a static_cast to that type.
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle>(
    ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle t)
{
    const std::string thread_block = to_string(static_cast<ThreadBlock_>(t));
    const std::string gemm         = to_string(static_cast<BwdXdlGemm_>(t));
    const std::string transfer     = to_string(static_cast<Transfer_<>>(t));
    return thread_block + "," + gemm + "," + transfer;
}
// Formats the algorithm as
// "<ThreadBlock_>,<DlThreadConfig_>,<DlThreadCluster_>,<DlTransfer_<5>>", where each part is
// the string form of t after a static_cast to that type.
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl>(
    ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl t)
{
    const std::string thread_block   = to_string(static_cast<ThreadBlock_>(t));
    const std::string thread_config  = to_string(static_cast<DlThreadConfig_>(t));
    const std::string thread_cluster = to_string(static_cast<DlThreadCluster_>(t));
    const std::string transfer       = to_string(static_cast<DlTransfer_<5>>(t));
    return thread_block + "," + thread_config + "," + thread_cluster + "," + transfer;
}
// Formats the algorithm as "<ThreadBlock_>,<BwdXdlGemm_>,<Transfer_<4>>", where each part is
// the string form of t after a static_cast to that type.
template <>
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle>(
    ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle t)
{
    const std::string thread_block = to_string(static_cast<ThreadBlock_>(t));
    const std::string gemm         = to_string(static_cast<BwdXdlGemm_>(t));
    const std::string transfer     = to_string(static_cast<Transfer_<4>>(t));
    return thread_block + "," + gemm + "," + transfer;
}
} // namespace ck_tile::builder::test

View File

@@ -151,7 +151,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
static bool __host__ __device__ BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
static TailNumber BlockLoopTailNum(index_t num_loop)
{
@@ -707,7 +710,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
// Returns whether num_loop is strictly greater than PrefetchStages.
__host__ __device__ static bool BlockHasHotloop(index_t num_loop)
{
    const bool exceeds_prefetch = PrefetchStages < num_loop;
    return exceeds_prefetch;
}
static TailNumber BlockLoopTailNum(index_t num_loop)
{

View File

@@ -3,6 +3,11 @@
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/stream_utility.hpp"
#include "device_grouped_gemm.hpp"
namespace ck {
@@ -43,6 +48,59 @@ struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout,
{
};
// Host-side helpers that size the launch grid of a tile-loop kernel from the kernel's
// occupancy and the number of available compute units.
//
// BlockSize: number of threads per workgroup the kernel is launched with.
template <ck::index_t BlockSize>
struct TileLoopKernelConfig
{
    // The oversubscription factor for the number of blocks that can simultaneously reside on
    // GPU.
    static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
    // Number of SIMD units assumed per compute unit.
    static constexpr int CU_SIMDS = 4;

    // Upper bound on workgroups per compute unit, assuming at most 2 waves per SIMD.
    //
    // NOTE(review): this was previously sketched as the constants BLOCK_WAVES / CU_BLOCKS; it is
    // a function here because it depends on get_warp_size() -- presumably not usable in a
    // host-side constant expression; confirm.
    static int GetCuBlocks()
    {
        const int block_waves = BlockSize / get_warp_size();
        // Clamp to one wave so that a BlockSize smaller than the warp size cannot cause a
        // division by zero below.
        const int waves_per_block = block_waves > 0 ? block_waves : 1;
        return ck::math::integer_divide_floor(2 * CU_SIMDS, waves_per_block);
    }

    // Returns the grid size: available compute-unit count multiplied by the smaller of the
    // kernel's occupancy (blocks per CU) and GetCuBlocks().
    //
    // kernel:        kernel function whose occupancy is queried.
    // stream_config: forwarded to getAvailableComputeUnitCount(); when its log_level_ is
    //                positive the intermediate values are printed to std::cout.
    template <typename KernelFunction>
    static int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
                                             const StreamConfig& stream_config)
    {
        // Calculate max number of workgroups that can simultaneously reside on the CU.
        const int occ_num_blocks = GetKernelOccupancy(kernel);
        const int cu_count       = getAvailableComputeUnitCount(stream_config);
        // Computed once so the logged value and the returned value cannot diverge.
        const int grid_size = cu_count * ck::math::min(occ_num_blocks, GetCuBlocks());

        if(stream_config.log_level_ > 0)
        {
            std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks
                      << ", available CUs count: " << cu_count
                      << ", occup. grid size: " << grid_size << std::endl;
        }

        return grid_size;
    }

    // Returns the value reported by hipOccupancyMaxActiveBlocksPerMultiprocessor for `kernel`
    // launched with BlockSize threads and 0 bytes of dynamic shared memory.
    template <typename KernelFunction>
    static int GetKernelOccupancy(const KernelFunction& kernel)
    {
        int occupancy = 0;
        ck::hip_check_error(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
        return occupancy;
    }

    // Returns multiProcessorCount from the properties of the current HIP device.
    static int GetComputeUnitCount()
    {
        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
        ck::hip_check_error(hipGetDevice(&dev));
        ck::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
        return dev_prop.multiProcessorCount;
    }
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

Some files were not shown because too many files have changed in this diff Show More