mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Merge branch 'develop' into LWPCK-3549-cleanups
This commit is contained in:
@@ -43,6 +43,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added top-k sigmoid kernel in CK_TILE
|
||||
* Added the blockscale 2D support for CK_TILE GEMM.
|
||||
* Added Flatmm pipeline for microscaling (MX) FP8/FP4 data types
|
||||
* Added reduce and multi reduction kernels
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
62
Jenkinsfile
vendored
62
Jenkinsfile
vendored
@@ -574,6 +574,8 @@ def cmake_build(Map conf=[:]){
|
||||
def setup_cmd
|
||||
def build_cmd
|
||||
def execute_cmd = conf.get("execute_cmd", "")
|
||||
//check the node gpu architecture
|
||||
def arch_name = check_arch_name()
|
||||
if(!setup_args.contains("NO_CK_BUILD")){
|
||||
if (params.NINJA_BUILD_TRACE) {
|
||||
echo "running ninja build trace"
|
||||
@@ -646,15 +648,15 @@ def cmake_build(Map conf=[:]){
|
||||
|
||||
//run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set
|
||||
if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){
|
||||
sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${check_arch_name()}.json"
|
||||
archiveArtifacts "ck_build_trace_${check_arch_name()}.json"
|
||||
sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${check_arch_name()}.json"
|
||||
sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${arch_name}.json"
|
||||
archiveArtifacts "ck_build_trace_${arch_name}.json"
|
||||
sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${arch_name}.json"
|
||||
if (params.NINJA_BUILD_TRACE || params.BUILD_INSTANCES_ONLY){
|
||||
if (params.NINJA_FTIME_TRACE) {
|
||||
echo "running ClangBuildAnalyzer"
|
||||
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log"
|
||||
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${check_arch_name()}.log"
|
||||
archiveArtifacts "clang_build_analysis_${check_arch_name()}.log"
|
||||
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${arch_name}.log"
|
||||
archiveArtifacts "clang_build_analysis_${arch_name}.log"
|
||||
}
|
||||
|
||||
|
||||
@@ -672,8 +674,8 @@ def cmake_build(Map conf=[:]){
|
||||
if(params.BUILD_PACKAGES){
|
||||
echo "Build ckProfiler packages"
|
||||
sh 'ninja -j64 package'
|
||||
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${check_arch_name()}.deb"
|
||||
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${check_arch_name()}"
|
||||
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb"
|
||||
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
|
||||
}
|
||||
}
|
||||
if(params.BUILD_INSTANCES_ONLY){
|
||||
@@ -699,16 +701,14 @@ def cmake_build(Map conf=[:]){
|
||||
if(params.BUILD_PACKAGES){
|
||||
echo "Build ckProfiler packages"
|
||||
sh 'ninja -j64 package'
|
||||
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${check_arch_name()}.deb"
|
||||
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${check_arch_name()}"
|
||||
sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb"
|
||||
stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//check the node gpu architecture
|
||||
def arch_name = check_arch_name()
|
||||
if (params.RUN_CK_TILE_FMHA_TESTS){
|
||||
try{
|
||||
archiveArtifacts "perf_fmha_*.log"
|
||||
@@ -1201,8 +1201,8 @@ pipeline {
|
||||
description: "Run the ck_tile FMHA tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_TILE_ENGINE_BASIC_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run the tile_engine_basic tests (default: OFF)")
|
||||
defaultValue: true,
|
||||
description: "Run the tile_engine_basic tests (default: ON)")
|
||||
booleanParam(
|
||||
name: "RUN_TILE_ENGINE_GEMM_TESTS",
|
||||
defaultValue: false,
|
||||
@@ -1650,7 +1650,10 @@ pipeline {
|
||||
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
|
||||
-D GEMM_PRESHUFFLE_LAYOUT="rcr" \
|
||||
-D GEMM_PRESHUFFLE_CONFIG_FILE="default_ci_config.json" .. && \
|
||||
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all """
|
||||
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args)
|
||||
@@ -1667,37 +1670,6 @@ pipeline {
|
||||
}
|
||||
parallel
|
||||
{
|
||||
stage("Run TILE_ENGINE_GEMM Tests on gfx90a")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx90a" \
|
||||
-D GEMM_UNIVERSAL_DATATYPE="fp8;fp16" \
|
||||
-D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \
|
||||
-D GEMM_STREAMK_DATATYPE="fp8;fp16" \
|
||||
-D GEMM_STREAMK_LAYOUT="rcr" \
|
||||
-D GEMM_MULTI_D_DATATYPE="fp16" \
|
||||
-D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
|
||||
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
|
||||
-D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \
|
||||
ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Run TILE_ENGINE_GEMM Tests on gfx942")
|
||||
{
|
||||
when {
|
||||
|
||||
@@ -44,6 +44,9 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl
|
||||
add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16)
|
||||
|
||||
add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16)
|
||||
|
||||
list(APPEND gpu_list_tf32 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/tuple.hpp>
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::hip_check_error;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::Tensor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddAdd = ck::tensor_operation::element_wise::AddAdd;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DDataType = F16;
|
||||
using DsDataType = ck::Tuple<DDataType, DDataType>;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DLayout = Row;
|
||||
using DsLayout = ck::Tuple<DLayout, DLayout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = AddAdd;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
static constexpr int NumDs = 2;
|
||||
|
||||
using DeviceGemmInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<4, 4, 4>>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_gemm_multiple_d_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
@@ -71,339 +71,6 @@ using DeviceGemmInstance =
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>;
|
||||
// clang-format on
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
#include "run_grouped_gemm_multiple_d_example.inc"
|
||||
|
||||
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
std::vector<std::vector<ck::index_t>> stride_Ds;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDs>;
|
||||
using GemmDesc = ck::tensor_operation::device::GemmDesc;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<GemmDesc> gemm_descs;
|
||||
std::vector<KernelArguments> ggemm_kargs;
|
||||
std::vector<void*> p_Cs;
|
||||
std::vector<const void*> p_As;
|
||||
std::vector<const void*> p_Bs;
|
||||
std::vector<std::array<const void*, NumDs>> p_Ds = {};
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
ggemm_kargs.reserve(group_count);
|
||||
p_As.reserve(group_count);
|
||||
p_Bs.reserve(group_count);
|
||||
p_Ds.reserve(group_count);
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Tensor<ADataType>> a_tensors;
|
||||
std::vector<Tensor<BDataType>> b_tensors;
|
||||
std::vector<std::array<Tensor<DDataType>, NumDs>> d_tensors;
|
||||
std::vector<Tensor<EDataType>> c_host_tensors;
|
||||
std::vector<Tensor<EDataType>> c_device_result_tensors;
|
||||
|
||||
a_tensors.reserve(group_count);
|
||||
b_tensors.reserve(group_count);
|
||||
d_tensors.reserve(group_count);
|
||||
c_host_tensors.reserve(group_count);
|
||||
c_device_result_tensors.reserve(group_count);
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
|
||||
std::vector<std::vector<DeviceMemPtr>> d_tensors_device;
|
||||
|
||||
a_tensors_device.reserve(group_count);
|
||||
b_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
d_tensors_device.resize(group_count); // reserve and update vector size
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
|
||||
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
|
||||
|
||||
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
|
||||
std::array<Tensor<DDataType>, NumDs> d_tens = {d0_tensor, d1_tensor};
|
||||
d_tensors.push_back(d_tens);
|
||||
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;
|
||||
|
||||
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
|
||||
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
|
||||
sizeof(BDataType) * b_tensors[i].GetElementSize() +
|
||||
sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs +
|
||||
sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
}
|
||||
break;
|
||||
case 2:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));
|
||||
b_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));
|
||||
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
|
||||
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
|
||||
}
|
||||
|
||||
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
|
||||
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
|
||||
}
|
||||
c_tensors_device[i]->SetZero();
|
||||
|
||||
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Ds.push_back(
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
|
||||
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
|
||||
|
||||
// The device op does not have to know M problem size at lunch time.
|
||||
gemm_descs.push_back({0,
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
problem_size.stride_Cs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]}});
|
||||
ggemm_kargs.push_back(
|
||||
{a_tensors_device[i]->GetDeviceBuffer(),
|
||||
b_tensors_device[i]->GetDeviceBuffer(),
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
// do GEMM
|
||||
auto argument = gemm.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
|
||||
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
|
||||
ggemm_kargs.data(),
|
||||
gemm.GetDeviceKernelArgSize(&argument),
|
||||
hipMemcpyHostToDevice));
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
auto karg = ggemm_kargs[i];
|
||||
auto dev_res_tensor =
|
||||
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{}));
|
||||
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data());
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
|
||||
b_tensors[i],
|
||||
d_tensors[i],
|
||||
c_host_tensors[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
|
||||
}
|
||||
|
||||
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::vector<int> argToIntArray(char* input)
|
||||
{
|
||||
std::vector<int> out;
|
||||
std::istringstream in(input);
|
||||
std::string item;
|
||||
|
||||
while(std::getline(in, item, ','))
|
||||
{
|
||||
out.push_back(std::stoi(item));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
if(argc < 10)
|
||||
{
|
||||
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
|
||||
problem_size.group_count = Ms.size();
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(Ms[i]);
|
||||
problem_size.Ns.push_back(252);
|
||||
problem_size.Ks.push_back(4608);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
|
||||
problem_size.stride_Ds.push_back({});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "Usage:\n"
|
||||
<< "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
|
||||
"64,64 64,64 128,128)\n"
|
||||
<< "... setting default values." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.Ms = argToIntArray(argv[4]);
|
||||
problem_size.Ns = argToIntArray(argv[5]);
|
||||
problem_size.Ks = argToIntArray(argv[6]);
|
||||
|
||||
problem_size.stride_As = argToIntArray(argv[7]);
|
||||
problem_size.stride_Bs = argToIntArray(argv[8]);
|
||||
problem_size.stride_Cs = argToIntArray(argv[9]);
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds.push_back(problem_size.stride_Cs);
|
||||
}
|
||||
|
||||
problem_size.group_count = problem_size.Ms.size();
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
|
||||
@@ -58,11 +58,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -57,11 +57,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -323,8 +323,8 @@ bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: async hargs (0=n0, 1=yes)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4: async hargs (0=no, 1=yes)\n");
|
||||
printf("arg5: group count (default=16)\n");
|
||||
#if defined(EXAMPLE_USE_SPLITK)
|
||||
printf("arg6: k-batch count (default=1)\n");
|
||||
|
||||
341
example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc
Normal file
341
example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc
Normal file
@@ -0,0 +1,341 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
|
||||
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
std::vector<std::vector<ck::index_t>> stride_Ds;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDs>;
|
||||
using GemmDesc = ck::tensor_operation::device::GemmDesc;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<GemmDesc> gemm_descs;
|
||||
std::vector<KernelArguments> ggemm_kargs;
|
||||
std::vector<void*> p_Cs;
|
||||
std::vector<const void*> p_As;
|
||||
std::vector<const void*> p_Bs;
|
||||
std::vector<std::array<const void*, NumDs>> p_Ds = {};
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
ggemm_kargs.reserve(group_count);
|
||||
p_As.reserve(group_count);
|
||||
p_Bs.reserve(group_count);
|
||||
p_Ds.reserve(group_count);
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Tensor<ADataType>> a_tensors;
|
||||
std::vector<Tensor<BDataType>> b_tensors;
|
||||
std::vector<std::array<Tensor<DDataType>, NumDs>> d_tensors;
|
||||
std::vector<Tensor<EDataType>> c_host_tensors;
|
||||
std::vector<Tensor<EDataType>> c_device_result_tensors;
|
||||
|
||||
a_tensors.reserve(group_count);
|
||||
b_tensors.reserve(group_count);
|
||||
d_tensors.reserve(group_count);
|
||||
c_host_tensors.reserve(group_count);
|
||||
c_device_result_tensors.reserve(group_count);
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
|
||||
std::vector<std::vector<DeviceMemPtr>> d_tensors_device;
|
||||
|
||||
a_tensors_device.reserve(group_count);
|
||||
b_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
d_tensors_device.resize(group_count); // reserve and update vector size
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
|
||||
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
|
||||
|
||||
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
|
||||
std::array<Tensor<DDataType>, NumDs> d_tens = {d0_tensor, d1_tensor};
|
||||
d_tensors.push_back(d_tens);
|
||||
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;
|
||||
|
||||
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
|
||||
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
|
||||
sizeof(BDataType) * b_tensors[i].GetElementSize() +
|
||||
sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs +
|
||||
sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
}
|
||||
break;
|
||||
case 2:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));
|
||||
b_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));
|
||||
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
|
||||
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
|
||||
}
|
||||
|
||||
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
|
||||
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
|
||||
}
|
||||
c_tensors_device[i]->SetZero();
|
||||
|
||||
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Ds.push_back(
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
|
||||
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
|
||||
|
||||
// The device op does not have to know M problem size at lunch time.
|
||||
gemm_descs.push_back({0,
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
problem_size.stride_Cs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]}});
|
||||
ggemm_kargs.push_back(
|
||||
{a_tensors_device[i]->GetDeviceBuffer(),
|
||||
b_tensors_device[i]->GetDeviceBuffer(),
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
// do GEMM
|
||||
auto argument = gemm.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
|
||||
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
|
||||
ggemm_kargs.data(),
|
||||
gemm.GetDeviceKernelArgSize(&argument),
|
||||
hipMemcpyHostToDevice));
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
auto karg = ggemm_kargs[i];
|
||||
auto dev_res_tensor =
|
||||
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{}));
|
||||
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data());
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
|
||||
b_tensors[i],
|
||||
d_tensors[i],
|
||||
c_host_tensors[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
|
||||
}
|
||||
|
||||
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::vector<int> argToIntArray(char* input)
|
||||
{
|
||||
std::vector<int> out;
|
||||
std::istringstream in(input);
|
||||
std::string item;
|
||||
|
||||
while(std::getline(in, item, ','))
|
||||
{
|
||||
out.push_back(std::stoi(item));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
if(argc < 10)
|
||||
{
|
||||
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
|
||||
problem_size.group_count = Ms.size();
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(Ms[i]);
|
||||
problem_size.Ns.push_back(252);
|
||||
problem_size.Ks.push_back(4608);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
|
||||
problem_size.stride_Ds.push_back({});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "Usage:\n"
|
||||
<< "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
|
||||
"64,64 64,64 128,128)\n"
|
||||
<< "... setting default values." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.Ms = argToIntArray(argv[4]);
|
||||
problem_size.Ns = argToIntArray(argv[5]);
|
||||
problem_size.Ks = argToIntArray(argv[6]);
|
||||
|
||||
problem_size.stride_As = argToIntArray(argv[7]);
|
||||
problem_size.stride_Bs = argToIntArray(argv[8]);
|
||||
problem_size.stride_Cs = argToIntArray(argv[9]);
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds.push_back(problem_size.stride_Cs);
|
||||
}
|
||||
|
||||
problem_size.group_count = problem_size.Ms.size();
|
||||
}
|
||||
|
||||
return run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
@@ -119,7 +119,7 @@ static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_an
|
||||
static constexpr bool MulRoutedWeight = false; // splitk gemm1 does not do routedWeight.
|
||||
|
||||
#if 1
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr ck::index_t MPerBlock = 64;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
|
||||
@@ -156,7 +156,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight,
|
||||
int32_t, A0DataType, A0DataType, A0DataType, A0DataType, true>;
|
||||
#else
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
@@ -171,7 +172,8 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight,
|
||||
int32_t, A0DataType, A0DataType, A0DataType, A0DataType, false>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
@@ -182,12 +184,14 @@ int main(int argc, char* argv[])
|
||||
bool time_kernel = true;
|
||||
#if 1
|
||||
// GEMM shape
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 6144;
|
||||
ck::index_t N = 1536;
|
||||
ck::index_t K = 4096;
|
||||
// ck::index_t N = 4096;
|
||||
// ck::index_t K = 6144;
|
||||
// ck::index_t N = 128;
|
||||
// ck::index_t K = 512;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t topk = 2;
|
||||
ck::index_t experts = 16;
|
||||
ck::index_t topk = 8;
|
||||
// ck::index_t sorted_tile_num = 515;
|
||||
// ck::index_t valid_tile_num = 512;
|
||||
// ck::index_t tokens = 208;
|
||||
@@ -196,9 +200,9 @@ int main(int argc, char* argv[])
|
||||
// ck::index_t sorted_tile_num = 259;
|
||||
// ck::index_t valid_tile_num = 256;
|
||||
// ck::index_t tokens = 4096;
|
||||
ck::index_t sorted_tile_num = 2;
|
||||
ck::index_t valid_tile_num = 2;
|
||||
ck::index_t tokens = 32;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 16;
|
||||
ck::index_t tokens = 4;
|
||||
#else
|
||||
// deepseek
|
||||
ck::index_t N = 2048;
|
||||
@@ -209,7 +213,7 @@ int main(int argc, char* argv[])
|
||||
ck::index_t sorted_tile_num = 261;
|
||||
ck::index_t valid_tile_num = 256;
|
||||
#endif
|
||||
ck::index_t KBatch = 6;
|
||||
ck::index_t KBatch = 1;
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
|
||||
@@ -36,7 +36,7 @@ DTYPE_BITS = {
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
|
||||
|
||||
SUPPORTED_PAGE_SIZE = [128, 256, 1024]
|
||||
SUPPORTED_PAGE_SIZE = [1, 128, 256, 1024]
|
||||
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
|
||||
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
|
||||
KV_MEMORY_LAYOUT_ENUM_MAP = {
|
||||
@@ -737,6 +737,8 @@ def get_fwd_blobs(
|
||||
|
||||
# Generate kernels for both page_size=16 and page_size=1024
|
||||
for page_size in SUPPORTED_PAGE_SIZE:
|
||||
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
|
||||
@@ -1351,8 +1351,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t> && supports_qscale)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o_host});
|
||||
return ck_tile::make_composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o_host});
|
||||
else if constexpr(supports_qscale)
|
||||
return ck_tile::scales{scale_o_host};
|
||||
else
|
||||
|
||||
@@ -15,6 +15,22 @@ list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-flo
|
||||
|
||||
target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS})
|
||||
|
||||
# Multi Reduce Threadwise Example
|
||||
set(EXAMPLE_MULTI_REDUCE "tile_example_multi_reduce_threadwise")
|
||||
add_executable(${EXAMPLE_MULTI_REDUCE} EXCLUDE_FROM_ALL multiple_reduce_threadwise.cpp)
|
||||
target_include_directories(${EXAMPLE_MULTI_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
target_compile_options(${EXAMPLE_MULTI_REDUCE} PRIVATE ${EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS})
|
||||
|
||||
# Multi Reduce Blockwise Example
|
||||
set(EXAMPLE_MULTI_REDUCE_BLOCKWISE "tile_example_multi_reduce_multiblock")
|
||||
add_executable(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} EXCLUDE_FROM_ALL multiple_reduce_multiblock.cpp)
|
||||
target_include_directories(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
target_compile_options(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
|
||||
271
example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp
Normal file
271
example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp
Normal file
@@ -0,0 +1,271 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <cstring>
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("n", "32", "n dimension")
|
||||
.insert("h", "19", "h dimension")
|
||||
.insert("w", "7", "w dimension")
|
||||
.insert("c", "512", "c dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "multi_reduce_multiblock.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType;
|
||||
using ComputeDataType = float;
|
||||
using YDataType = float;
|
||||
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t H = arg_parser.get_int("h");
|
||||
ck_tile::index_t W = arg_parser.get_int("w");
|
||||
ck_tile::index_t C = arg_parser.get_int("c");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
// Validate input dimensions
|
||||
const ck_tile::index_t kept_dim_len_prod = N * C;
|
||||
const ck_tile::index_t reduce_total_length = H * W;
|
||||
|
||||
if(kept_dim_len_prod == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C
|
||||
<< ", product=" << kept_dim_len_prod << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty output tensor." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if(reduce_total_length == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W
|
||||
<< ", product=" << reduce_total_length << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty reduction with no data to process." << std::endl;
|
||||
std::cerr << "The kernel will exit early without performing any computation." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> problem_shape = {N, H, W, C};
|
||||
std::vector<ck_tile::index_t> strides(4);
|
||||
strides[0] = H * W * C;
|
||||
strides[1] = W * C;
|
||||
strides[2] = C;
|
||||
strides[3] = 1;
|
||||
|
||||
// Define reduction specification:
|
||||
constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep
|
||||
constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce
|
||||
|
||||
ck_tile::HostTensor<XDataType> x_host(problem_shape, strides);
|
||||
ck_tile::HostTensor<YDataType> y_host_add_ref({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_ref({N, C}, {C, 1});
|
||||
auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_max_ref);
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_add_dev({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_dev({N, C}, {C, 1});
|
||||
auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_max_dev);
|
||||
|
||||
const auto number_operations = y_host_dev_tuple.size();
|
||||
|
||||
std::vector<YDataType> h(number_operations * N * C);
|
||||
|
||||
auto y_buf_size = number_operations *
|
||||
y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes();
|
||||
ck_tile::DeviceMem y_buf(y_buf_size);
|
||||
|
||||
const auto output_tensor_offset = N * C;
|
||||
|
||||
// Operations: one doing a sum reduction, the other computing the mean square
|
||||
// In the case of mean square:
|
||||
// 1. The element wise operation squares each element before reduction
|
||||
// 2. The reduction operation sum the squared element
|
||||
// 3. The accumulator element wise operation divides the result by the total number of reduced
|
||||
// elements (intra block operation)
|
||||
// 4. The partial result is updated across blocks using inter block reduction, a sum.
|
||||
auto reduce_ops =
|
||||
ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reductions
|
||||
auto elementwise_ops = ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnarySquare{}); // Elementwise
|
||||
// ops
|
||||
auto accumulator_elementwise_ops = ck_tile::make_tuple(
|
||||
ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnaryDivide{
|
||||
reduce_total_length}); // Accumulator Elementwise ops on reduction, intra block
|
||||
auto inter_block_reduce_ops = ck_tile::make_tuple(
|
||||
ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // Inter block reduction
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 128>;
|
||||
using WarpTile = ck_tile::sequence<32, 128>;
|
||||
using ThreadTile = ck_tile::sequence<8, 8>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
|
||||
using Problem = ck_tile::Reduce2dProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
Shape,
|
||||
decltype(reduce_ops),
|
||||
decltype(kept_dim),
|
||||
decltype(reduce_dims),
|
||||
4>;
|
||||
|
||||
using Kernel = ck_tile::MultiReduceMultiblock<Problem>;
|
||||
|
||||
// Determine block group size for multi-block reduction
|
||||
// block_group_size records how many blocks participate to a reduction (input data dependent)
|
||||
// , for efficiency reasons this size if limited to a maximum of 128. If this is not sufficient
|
||||
// to process the whole reduction, each thread will to process multiple thread tile
|
||||
// a num_block_tile_iterations times
|
||||
auto [num_block_tile_iterations, block_group_size] =
|
||||
typename Kernel::TilePartitioner{reduce_total_length}.GetBlockGroupParams();
|
||||
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
ck_tile::index_t kGridSize =
|
||||
((kept_dim_len_prod + Shape::Block_M - 1) / Shape::Block_M) * block_group_size;
|
||||
|
||||
std::cout << "Block group size: " << block_group_size
|
||||
<< ", Num block tile iterations: " << num_block_tile_iterations
|
||||
<< ", Reduce total length: " << reduce_total_length << std::endl;
|
||||
std::cout << "grid size " << kGridSize << ", block size " << kBlockSize << std::endl;
|
||||
|
||||
// Create input tensor shape and strides
|
||||
auto input_shape =
|
||||
ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
|
||||
auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(
|
||||
C, input_strides)) // output tensor's continuous dimension and input strides
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported!\n");
|
||||
}
|
||||
|
||||
// Init the output data with identity values respective to each reduce op
|
||||
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
constexpr auto op = reduce_ops.at(i);
|
||||
const auto identity_val = op.template GetIdentityValue<YDataType>();
|
||||
const auto output_number_elements = N * C;
|
||||
std::fill(h.begin() + i * output_number_elements,
|
||||
h.begin() + (i + 1) * output_number_elements,
|
||||
identity_val);
|
||||
});
|
||||
|
||||
auto clear_output_buffer = [&]() { y_buf.ToDevice(h.data()); };
|
||||
|
||||
float ave_time = launch_kernel_time_mask(
|
||||
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
clear_output_buffer,
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
input_strides,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
output_tensor_offset,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops,
|
||||
inter_block_reduce_ops)
|
||||
|
||||
);
|
||||
|
||||
std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
// reference
|
||||
ck_tile::reference_multiple_reduce_multiblock<XDataType, ComputeDataType, YDataType>(
|
||||
x_host,
|
||||
y_host_ref_tuple,
|
||||
reduce_ops,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops,
|
||||
inter_block_reduce_ops,
|
||||
block_group_size);
|
||||
std::cout << "Read " << y_buf_size / 10 << " Bytes from the device" << std::endl;
|
||||
|
||||
// Transfer data from device and check error for each operation
|
||||
y_buf.FromDevice(h.data());
|
||||
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
std::memcpy(y_host_dev_tuple.get(ck_tile::number<i>{}).data(),
|
||||
h.data() + i * output_tensor_offset,
|
||||
output_tensor_offset * sizeof(YDataType));
|
||||
std::cout << "Checking operation " << i << ": " << std::endl;
|
||||
|
||||
bool pass_op = ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number<i>{}),
|
||||
y_host_ref_tuple.get(ck_tile::number<i>{}));
|
||||
|
||||
if(pass_op)
|
||||
{
|
||||
std::cout << "✅ valid results for this operation" << std::endl;
|
||||
}
|
||||
pass &= pass_op;
|
||||
});
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
224
example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp
Normal file
224
example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp
Normal file
@@ -0,0 +1,224 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <cstring>
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("n", "32", "n dimension")
|
||||
.insert("h", "7", "h dimension")
|
||||
.insert("w", "7", "w dimension")
|
||||
.insert("c", "512", "c dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "multi_reduce.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType;
|
||||
using ComputeDataType = float;
|
||||
using YDataType = DataType;
|
||||
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t H = arg_parser.get_int("h");
|
||||
ck_tile::index_t W = arg_parser.get_int("w");
|
||||
ck_tile::index_t C = arg_parser.get_int("c");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
// Validate input dimensions
|
||||
const ck_tile::index_t kept_dim_len_prod = N * C;
|
||||
const ck_tile::index_t reduce_total_length = H * W;
|
||||
|
||||
if(kept_dim_len_prod == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C
|
||||
<< ", product=" << kept_dim_len_prod << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty output tensor." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if(reduce_total_length == 0)
|
||||
{
|
||||
std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W
|
||||
<< ", product=" << reduce_total_length << ")." << std::endl;
|
||||
std::cerr << "This will result in an empty reduction with no data to process." << std::endl;
|
||||
std::cerr << "The kernel will exit early without performing any computation." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> problem_shape = {N, H, W, C};
|
||||
std::vector<ck_tile::index_t> strides(4);
|
||||
strides[0] = H * W * C;
|
||||
strides[1] = W * C;
|
||||
strides[2] = C;
|
||||
strides[3] = 1;
|
||||
|
||||
// Define reduction specification:
|
||||
constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep
|
||||
constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce
|
||||
|
||||
ck_tile::HostTensor<XDataType> x_host(problem_shape, strides);
|
||||
ck_tile::HostTensor<YDataType> y_host_add_ref({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_ref({N, C}, {C, 1});
|
||||
auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_max_ref);
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_add_dev({N, C}, {C, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_max_dev({N, C}, {C, 1});
|
||||
auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_max_dev);
|
||||
|
||||
const auto number_operations = y_host_dev_tuple.size();
|
||||
|
||||
// Two operations: one do a sum reduction, the other computing the mean square
|
||||
auto reduce_ops =
|
||||
ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reductions ops
|
||||
auto elementwise_ops =
|
||||
ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnarySquare{}); // Elementwise ops
|
||||
auto accumulator_elementwise_ops =
|
||||
ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
|
||||
ck_tile::element_wise::UnaryDivide{
|
||||
reduce_total_length}); // Accumulator Elementiwise ops on reduction,
|
||||
|
||||
auto y_buf_size = number_operations *
|
||||
y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes();
|
||||
ck_tile::DeviceMem y_buf(y_buf_size);
|
||||
|
||||
const auto output_tensor_offset = N * C;
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 128>;
|
||||
using WarpTile = ck_tile::sequence<32, 128>;
|
||||
using ThreadTile = ck_tile::sequence<8, 8>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) /
|
||||
BlockTile::at(ck_tile::number<0>{});
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
|
||||
using Problem = ck_tile::Reduce2dProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
Shape,
|
||||
decltype(reduce_ops),
|
||||
decltype(kept_dim),
|
||||
decltype(reduce_dims),
|
||||
4>;
|
||||
|
||||
using Kernel = ck_tile::MultiReduceThreadWise<Problem>;
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
|
||||
// Create input tensor shape and strides
|
||||
auto input_shape =
|
||||
ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
|
||||
auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(
|
||||
C, input_strides)) // output tensor's continuous dimension and input strides
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported!\n");
|
||||
}
|
||||
|
||||
float ave_time = launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
input_strides,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
output_tensor_offset,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops));
|
||||
|
||||
std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
std::vector<YDataType> h(number_operations * N * C);
|
||||
|
||||
// reference
|
||||
ck_tile::reference_multiple_reduce<XDataType, ComputeDataType, YDataType>(
|
||||
x_host,
|
||||
y_host_ref_tuple,
|
||||
reduce_ops,
|
||||
kept_dim,
|
||||
reduce_dims,
|
||||
elementwise_ops,
|
||||
accumulator_elementwise_ops);
|
||||
std::cout << "Read " << y_buf_size / 10 << " Bytes from the device" << std::endl;
|
||||
|
||||
// Transfer data from device and check error for each operation
|
||||
y_buf.FromDevice(h.data());
|
||||
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
|
||||
std::memcpy(y_host_dev_tuple.get(ck_tile::number<i>{}).data(),
|
||||
h.data() + i * output_tensor_offset,
|
||||
output_tensor_offset * sizeof(YDataType));
|
||||
pass &= ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number<i>{}),
|
||||
y_host_ref_tuple.get(ck_tile::number<i>{}));
|
||||
});
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
@@ -45,6 +45,11 @@ cmake
|
||||
..
|
||||
```
|
||||
|
||||
Note: The tests for WMMA builders are only built when `CK_USE_WMMA` is enabled. Add e.g.
|
||||
`gfx1121` or any of the other `gfx11`/`gfx12` architectures to the GPU targets. Alternatively,
|
||||
one can add flag `-D CK_USE_WMMA=ON` to build the tests. For the end-to-end tests that use
|
||||
the instances from builder, one needs an actual Navi card.
|
||||
|
||||
## Building and Testing
|
||||
|
||||
The builder test suite is organized into two main categories:
|
||||
|
||||
@@ -85,21 +85,23 @@ The top-level signature contains global properties that apply to the entire conv
|
||||
template <typename T>
|
||||
concept ConvSignatureDescriptor = requires(T t) {
|
||||
{ t.spatial_dim } -> std::convertible_to<unsigned int>; // 1, 2, or 3
|
||||
{ t.data_type } -> std::convertible_to<DataType>; // Default data type
|
||||
{ t.input } -> ConvTensorDescriptor;
|
||||
{ t.weight } -> ConvTensorDescriptor;
|
||||
{ t.output } -> ConvTensorDescriptor;
|
||||
requires ConvolutionDirectionWellDefinedIfProvided<T>; // Optional direction
|
||||
requires detail::DataTypeWellDefinedIfProvided<T>; // Optional default data type
|
||||
requires detail::ElementwiseOpWellDefinedIfProvided<T>; // Optional default elementwise operation
|
||||
};
|
||||
```
|
||||
|
||||
**Properties:**
|
||||
- **`spatial_dim`**: Dimensionality of the convolution (1D, 2D, or 3D)
|
||||
- **`direction`**: Operation type (optional, defaults to FORWARD)
|
||||
- **`direction`**: Operation type (Optional, defaults to FORWARD)
|
||||
- `FORWARD`: Standard forward convolution
|
||||
- `BACKWARD_DATA`: Gradient computation w.r.t. input
|
||||
- `BACKWARD_WEIGHT`: Gradient computation w.r.t. weights
|
||||
- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8)
|
||||
- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8). (Optional, defaults to UNDEFINED_DATA_TYPE, may be overridden by tensors)
|
||||
- **`operation`**: Default Operation (Optional, defaults to PASS_THROUGH, may be overridden by tensors)
|
||||
- **`accumulation_data_type`**: Type used for internal accumulation
|
||||
|
||||
#### 2. Tensor Level
|
||||
@@ -116,7 +118,7 @@ concept ConvTensorDescriptor = requires(T t) {
|
||||
|
||||
A tensor descriptor encapsulates:
|
||||
- **Configuration**: Layout and data type information
|
||||
- **Operation** (optional): Fused elementwise operations on this tensor
|
||||
- **operation** Fused elementwise operations on this tensor (Optional, default provided by ConvSignatureDescriptor)
|
||||
|
||||
#### 3. Tensor Configuration
|
||||
|
||||
@@ -126,7 +128,7 @@ Describes the memory layout and data types:
|
||||
template <typename T>
|
||||
concept TensorConfigDescriptor = requires(T t) {
|
||||
{ t.layout } -> std::convertible_to<ConvLayout>;
|
||||
{ t.data_type } -> std::convertible_to<DataType>; // Optional override
|
||||
requires detail::DataTypeWellDefinedIfProvided<T>; // Override data type (Optional, default provided by ConvSignatureDescriptor)
|
||||
};
|
||||
```
|
||||
|
||||
|
||||
@@ -15,29 +15,31 @@ namespace ck_tile::builder {
|
||||
/* Descriptors for individual elements of the algorithm description */
|
||||
/********************************************************************/
|
||||
|
||||
// Common concept for size-related fields
|
||||
template <typename T>
|
||||
concept SizeType = std::unsigned_integral<std::remove_cvref_t<T>>;
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem.
|
||||
template <typename T>
|
||||
concept ThreadBlockDescriptor = requires(T t) {
|
||||
{ t.block_size } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.m } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.n } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.k } -> std::convertible_to<size_t>;
|
||||
{ t.block_size } -> SizeType;
|
||||
{ t.tile_size.m } -> SizeType;
|
||||
{ t.tile_size.n } -> SizeType;
|
||||
{ t.tile_size.k } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept for parameters that describe a gridwise XDL GEMM problem.
|
||||
template <typename T>
|
||||
concept GridwiseXdlGemmDescriptor = requires(T t) {
|
||||
{ t.ak1 } -> std::convertible_to<size_t>;
|
||||
{ t.bk1 } -> std::convertible_to<size_t>;
|
||||
{ t.m_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.n_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.m_xdl_per_wave } -> std::convertible_to<size_t>;
|
||||
{ t.n_xdl_per_wave } -> std::convertible_to<size_t>;
|
||||
{ t.m_per_xdl } -> SizeType;
|
||||
{ t.n_per_xdl } -> SizeType;
|
||||
{ t.m_xdl_per_wave } -> SizeType;
|
||||
{ t.n_xdl_per_wave } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept for parameter that describe block GEMM problem.
|
||||
template <typename T>
|
||||
concept BlockGemmDescriptor = requires(T t) {
|
||||
concept BlockGemmPipelineDescriptor = requires(T t) {
|
||||
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ t.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
@@ -45,37 +47,48 @@ concept BlockGemmDescriptor = requires(T t) {
|
||||
// Concept for parameters that describe a gridwise WMMA GEMM problem.
|
||||
template <typename T>
|
||||
concept GridwiseWmmaGemmDescriptor = requires(T t) {
|
||||
{ t.k1 } -> std::convertible_to<size_t>;
|
||||
{ t.m_per_wmma } -> std::convertible_to<size_t>;
|
||||
{ t.n_per_wmma } -> std::convertible_to<size_t>;
|
||||
{ t.m_wmma_per_wave } -> std::convertible_to<size_t>;
|
||||
{ t.n_wmma_per_wave } -> std::convertible_to<size_t>;
|
||||
{ t.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ t.k1 } -> SizeType;
|
||||
{ t.m_per_wmma } -> SizeType;
|
||||
{ t.n_per_wmma } -> SizeType;
|
||||
{ t.m_wmma_per_wave } -> SizeType;
|
||||
{ t.n_wmma_per_wave } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept for vectorized data transfer for convolution input tensors.
|
||||
template <typename T>
|
||||
concept BlockTransferDescriptor = requires(T t) {
|
||||
{ t.k0 } -> std::convertible_to<size_t>;
|
||||
{ t.m_n } -> std::convertible_to<size_t>;
|
||||
{ t.k1 } -> std::convertible_to<size_t>;
|
||||
concept BlockTransferDescriptor3D = requires(T t) {
|
||||
{ t.k0 } -> SizeType;
|
||||
{ t.m_n } -> SizeType;
|
||||
{ t.k1 } -> SizeType;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept BlockTransferDescriptor4D = requires(T t) {
|
||||
{ t.k0 } -> SizeType;
|
||||
{ t.m_n } -> SizeType;
|
||||
{ t.k1 } -> SizeType;
|
||||
{ t.k_batch_size } -> SizeType;
|
||||
};
|
||||
|
||||
template <typename T, size_t ThreadClusterRank>
|
||||
concept BlockTransferDescriptor = (ThreadClusterRank == 3 && BlockTransferDescriptor3D<T>) ||
|
||||
(ThreadClusterRank == 4 && BlockTransferDescriptor4D<T>);
|
||||
|
||||
// Concept for thread cluster dimensions for GEMM output tensor.
|
||||
template <typename T>
|
||||
concept ThreadClusterDescriptor = requires(T t) {
|
||||
{ t.m_block } -> std::convertible_to<size_t>;
|
||||
{ t.m_wave_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.n_block } -> std::convertible_to<size_t>;
|
||||
{ t.n_wave_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.m_block } -> SizeType;
|
||||
{ t.m_wave_per_xdl } -> SizeType;
|
||||
{ t.n_block } -> SizeType;
|
||||
{ t.n_wave_per_xdl } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept for the LDS transfer for the convolution input tensors.
|
||||
template <typename T>
|
||||
concept LdsTransferDescriptor = requires(T t) {
|
||||
{ t.src_vector_dim } -> std::convertible_to<size_t>;
|
||||
{ t.src_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.lds_dst_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.src_vector_dim } -> SizeType;
|
||||
{ t.src_scalar_per_vector } -> SizeType;
|
||||
{ t.lds_dst_scalar_per_vector } -> SizeType;
|
||||
{ t.is_direct_load } -> std::convertible_to<bool>;
|
||||
{ t.lds_padding } -> std::convertible_to<bool>;
|
||||
};
|
||||
@@ -84,33 +97,35 @@ concept LdsTransferDescriptor = requires(T t) {
|
||||
// LDS).
|
||||
template <typename T>
|
||||
concept EpilogueDescriptor = requires(T t) {
|
||||
{ t.m_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
|
||||
{ t.n_per_wave_per_shuffle } -> std::convertible_to<size_t>;
|
||||
{ t.scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.m_xdl_per_wave_per_shuffle } -> SizeType;
|
||||
{ t.n_per_wave_per_shuffle } -> SizeType;
|
||||
{ t.scalar_per_vector } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept for the thread cluster access order
|
||||
template <typename T>
|
||||
concept AccessOrderDescriptor = requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
|
||||
} || requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
};
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
|
||||
// size is deduced from block gemm structure).
|
||||
template <typename T>
|
||||
concept TileThreadBlockDescriptor = requires(T t) {
|
||||
{ t.tile_size.m } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.n } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.k } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.m } -> SizeType;
|
||||
{ t.tile_size.n } -> SizeType;
|
||||
{ t.tile_size.k } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
|
||||
// size is deduced from block gemm structure).
|
||||
template <typename T>
|
||||
concept TileTransferDescriptor = requires(T t) {
|
||||
{ t.a_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.b_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.c_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.a_scalar_per_vector } -> SizeType;
|
||||
{ t.b_scalar_per_vector } -> SizeType;
|
||||
{ t.c_scalar_per_vector } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block GEMM (CK Tile).
|
||||
@@ -159,30 +174,51 @@ concept SpecifiesTileThreadBlock = requires {
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseXdlGemm = requires {
|
||||
{ T::gridwise_gemm } -> GridwiseXdlGemmDescriptor;
|
||||
concept GridwiseFwdXdlGemmDescriptor = requires(T t) {
|
||||
{ t.ak1 } -> SizeType;
|
||||
{ t.bk1 } -> SizeType;
|
||||
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept GridwiseBwdXdlGemmDescriptor = requires(T t) {
|
||||
{ t.k1 } -> SizeType;
|
||||
{ t.xdl_params } -> GridwiseXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseFwdXdlGemm = requires(T t) {
|
||||
{ t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise XDL GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseBwdXdlGemm = requires(T t) {
|
||||
{ t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise WMMA GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseWmmaGemm = requires {
|
||||
{ T::gridwise_gemm } -> GridwiseWmmaGemmDescriptor;
|
||||
concept SpecifiesGridwiseWmmaGemm = requires(T t) {
|
||||
{ t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies convolution input and output block transfer info.
|
||||
template <typename T>
|
||||
template <typename T, size_t ThreadClusterRank = 3>
|
||||
concept SpecifiesBlockTransfer = requires(T t) {
|
||||
{ T::transfer.a.block_transfer } -> BlockTransferDescriptor;
|
||||
{ T::transfer.b.block_transfer } -> BlockTransferDescriptor;
|
||||
{ T::transfer.a.block_transfer } -> BlockTransferDescriptor<ThreadClusterRank>;
|
||||
{ T::transfer.b.block_transfer } -> BlockTransferDescriptor<ThreadClusterRank>;
|
||||
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies convolution scalar per vector infor for A, B and C.
|
||||
template <typename T>
|
||||
concept SpecifiesTileTransfer = requires(T t) {
|
||||
{ T::transfer.a_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ T::transfer.b_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ T::transfer.c_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ T::transfer.a_scalar_per_vector } -> SizeType;
|
||||
{ T::transfer.b_scalar_per_vector } -> SizeType;
|
||||
{ T::transfer.c_scalar_per_vector } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
|
||||
@@ -210,8 +246,12 @@ concept SpecifiesSourceAccessOrder = requires(T t) {
|
||||
// Concept to check if struct specifies block GEMM.
|
||||
template <typename T>
|
||||
concept SpecifiesBlockGemm = requires {
|
||||
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
{ T::block_gemm_pipeline } -> BlockGemmPipelineDescriptor;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseGemmPipeline = requires {
|
||||
{ T::pipeline_version } -> std::convertible_to<PipelineVersion>;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block GEMM (CK Tile).
|
||||
@@ -244,7 +284,12 @@ concept SpecifiesTileConvSpecialization = requires {
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesFwdConvSpecialization = requires {
|
||||
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
|
||||
{ T::fwd_specialization } -> std::convertible_to<ConvSpecialization>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesBwdWeightConvSpecialization = requires {
|
||||
{ T::bwd_weight_specialization } -> std::convertible_to<ConvSpecialization>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -254,12 +299,12 @@ concept SpecifiesGemmSpecialization = requires {
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesNumPrefetchStages = requires {
|
||||
{ T::num_gemm_k_prefetch_stages } -> std::convertible_to<size_t>;
|
||||
{ T::num_gemm_k_prefetch_stages } -> SizeType;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesNumGroupsToMerge = requires {
|
||||
{ T::num_groups_to_merge } -> std::convertible_to<size_t>;
|
||||
{ T::num_conv_groups_to_merge } -> SizeType;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -267,12 +312,59 @@ concept SpecifiesLoopScheduler = requires {
|
||||
{ T::loop_scheduler } -> std::convertible_to<PipelineScheduler>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesGenericInstance = !requires {
|
||||
{ T::specialization };
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesTransposeTransfer = requires {
|
||||
{ T::max_transpose_transfer_src_scalar_per_vector } -> SizeType;
|
||||
{ T::max_transpose_transfer_dst_scalar_per_vector } -> SizeType;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept HasTransposeTransfer = requires {
|
||||
{ T::max_transpose_transfer_src_scalar_per_vector };
|
||||
{ T::max_transpose_transfer_dst_scalar_per_vector };
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept TransposeTransferWellDefinedIfProvided =
|
||||
!HasTransposeTransfer<T> || SpecifiesTransposeTransfer<T>;
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesGemmBatchOptions = requires {
|
||||
{ T::num_conv_groups_to_merge } -> SizeType;
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
/* Algorithm specialization concepts */
|
||||
/******************************************** */
|
||||
template <typename T>
|
||||
concept SpecifiesLargeTensorSupport = requires {
|
||||
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
|
||||
requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesReferenceAlgorithm = requires {
|
||||
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
|
||||
requires T::specialization == ConvAlgorithmSpecialization::REFERENCE;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesTwoStageSupport = requires {
|
||||
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
|
||||
requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesMultipleDSupport = requires {
|
||||
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
|
||||
requires T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D;
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
/* DL-specific descriptors and requirements */
|
||||
/******************************************** */
|
||||
@@ -280,11 +372,11 @@ concept SpecifiesLargeTensorSupport = requires {
|
||||
// Concept for DL thread configuration
|
||||
template <typename T>
|
||||
concept DlThreadConfigDescriptor = requires(T t) {
|
||||
{ t.k0_per_block } -> std::convertible_to<size_t>;
|
||||
{ t.k1 } -> std::convertible_to<size_t>;
|
||||
{ t.m1_per_thread } -> std::convertible_to<size_t>;
|
||||
{ t.n1_per_thread } -> std::convertible_to<size_t>;
|
||||
{ t.k_per_thread } -> std::convertible_to<size_t>;
|
||||
{ t.k0_per_block } -> SizeType;
|
||||
{ t.k1 } -> SizeType;
|
||||
{ t.m1_per_thread } -> SizeType;
|
||||
{ t.n1_per_thread } -> SizeType;
|
||||
{ t.k_per_thread } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept for DL thread cluster
|
||||
@@ -295,23 +387,29 @@ concept DlThreadClusterDescriptor = requires(T t) {
|
||||
};
|
||||
|
||||
// Concept for DL block transfer
|
||||
template <typename T>
|
||||
template <typename T, size_t N>
|
||||
concept DlBlockTransferDescriptor = requires(T t) {
|
||||
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_access_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, 4>>;
|
||||
{ t.thread_slice_lengths } -> std::convertible_to<std::array<size_t, N>>;
|
||||
{ t.thread_cluster_lengths } -> std::convertible_to<std::array<size_t, N>>;
|
||||
{ t.thread_cluster_arrange_order } -> std::convertible_to<std::array<size_t, N>>;
|
||||
{ t.src_access_order } -> std::convertible_to<std::array<size_t, N>>;
|
||||
{ t.src_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, N>>;
|
||||
{ t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to<std::array<size_t, N>>;
|
||||
{ t.dst_vector_tensor_lengths } -> std::convertible_to<std::array<size_t, N>>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept DlBlockTransferDescriptor4D = DlBlockTransferDescriptor<T, 4>;
|
||||
|
||||
template <typename T>
|
||||
concept DlBlockTransferDescriptor5D = DlBlockTransferDescriptor<T, 5>;
|
||||
|
||||
// Concept for DL epilogue
|
||||
template <typename T>
|
||||
concept DlEpilogueDescriptor = requires(T t) {
|
||||
{ t.src_dst_access_order } -> std::convertible_to<std::array<size_t, 6>>;
|
||||
{ t.src_dst_vector_dim } -> std::convertible_to<size_t>;
|
||||
{ t.dst_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.src_dst_vector_dim } -> SizeType;
|
||||
{ t.dst_scalar_per_vector } -> SizeType;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL thread config
|
||||
@@ -328,15 +426,21 @@ concept SpecifiesDlThreadCluster = requires {
|
||||
|
||||
// Concept to check if algorithm specifies DL block transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlBlockTransfer = requires {
|
||||
{ T::transfer.a.block_transfer } -> DlBlockTransferDescriptor;
|
||||
{ T::transfer.b.block_transfer } -> DlBlockTransferDescriptor;
|
||||
concept SpecifiesDlFwdBlockTransfer = requires {
|
||||
{ T::transfer.a } -> DlBlockTransferDescriptor4D;
|
||||
{ T::transfer.b } -> DlBlockTransferDescriptor4D;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesDlBwdBlockTransfer = requires {
|
||||
{ T::transfer.a } -> DlBlockTransferDescriptor5D;
|
||||
{ T::transfer.b } -> DlBlockTransferDescriptor5D;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL C thread transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlEpilogue = requires {
|
||||
{ T::transfer.c.epilogue } -> DlEpilogueDescriptor;
|
||||
{ T::transfer.c } -> DlEpilogueDescriptor;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -29,10 +29,20 @@ concept OutputVectorTransferLimits = requires {
|
||||
|
||||
// Limits for access order. Must be a permutation of {0, 1, 2}.
|
||||
template <auto Value>
|
||||
concept AccessOrderLimits = requires {
|
||||
concept AccessOrderLimits3D = requires {
|
||||
requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) &&
|
||||
(Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) &&
|
||||
(Value[2] >= 0 && Value[2] < 3));
|
||||
(Value[2] >= 0 && Value[2] < 3) && (Value.Size() == 3));
|
||||
};
|
||||
|
||||
// Limits for access order. Must be a permutation of {0, 1, 2, 3}.
|
||||
template <auto Value>
|
||||
concept AccessOrderLimits4D = requires {
|
||||
requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[0] != Value[3]) &&
|
||||
(Value[1] != Value[2]) && (Value[1] != Value[3]) && (Value[2] != Value[3]) &&
|
||||
(Value[0] >= 0 && Value[0] < 4) && (Value[1] >= 0 && Value[1] < 4) &&
|
||||
(Value[2] >= 0 && Value[2] < 4) && (Value[3] >= 0 && Value[3] < 4) &&
|
||||
(Value.Size() == 4));
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -80,6 +80,7 @@ concept ConvOutputLayout3D =
|
||||
(L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) ||
|
||||
(L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided);
|
||||
|
||||
namespace detail {
|
||||
template <typename T>
|
||||
concept HasDataType = requires(T t) {
|
||||
{ t.data_type };
|
||||
@@ -94,10 +95,11 @@ concept DataTypeWellDefinedIfProvided = requires(T t) {
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
template <typename T>
|
||||
concept TensorConfigDescriptor = requires(T t) {
|
||||
{ t.layout } -> std::convertible_to<TensorLayout>;
|
||||
requires DataTypeWellDefinedIfProvided<T>;
|
||||
requires detail::DataTypeWellDefinedIfProvided<T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -116,7 +118,6 @@ template <typename T, std::size_t N>
|
||||
struct IsArrayOfTensorConfigDescriptors<std::array<T, N>> : std::true_type
|
||||
{
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename T>
|
||||
concept ConvertibleToArrayOfTensorConfigs =
|
||||
@@ -128,11 +129,12 @@ concept AuxiliaryOperandConfigsWellDefinedIfProvided = requires(T t) {
|
||||
{ t.auxiliary_operand_configs } -> ConvertibleToArrayOfTensorConfigs;
|
||||
};
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename T>
|
||||
concept TensorOperatorDescriptor = requires(T t) {
|
||||
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
|
||||
requires AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
|
||||
requires detail::AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -140,6 +142,8 @@ concept HasTensorOp = requires(T t) {
|
||||
{ t.operation };
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename T>
|
||||
concept HasConvolutionDirection = requires(T t) {
|
||||
{ t.direction };
|
||||
@@ -159,11 +163,13 @@ concept ConvolutionDirectionWellDefinedIfProvided = requires(T t) {
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Concept for the convolution tensor
|
||||
template <typename T>
|
||||
concept ConvTensorDescriptor = requires(T t) {
|
||||
{ t.config } -> TensorConfigDescriptor;
|
||||
requires ElementwiseOpWellDefinedIfProvided<T>;
|
||||
requires detail::ElementwiseOpWellDefinedIfProvided<T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -179,8 +185,9 @@ concept ConvSignatureDescriptor = requires(T t) {
|
||||
{ t.input } -> ConvTensorDescriptor;
|
||||
{ t.weight } -> ConvTensorDescriptor;
|
||||
{ t.output } -> ConvTensorDescriptor;
|
||||
requires ConvolutionDirectionWellDefinedIfProvided<T>;
|
||||
requires DataTypeWellDefinedIfProvided<T>;
|
||||
requires detail::ConvolutionDirectionWellDefinedIfProvided<T>;
|
||||
requires detail::DataTypeWellDefinedIfProvided<T>;
|
||||
requires detail::ElementwiseOpWellDefinedIfProvided<T>;
|
||||
};
|
||||
|
||||
// Concept to validate a convolution signature's values.
|
||||
@@ -221,4 +228,13 @@ concept ValidConvWeightLayoutForSpatialDim =
|
||||
(SpatialDim == 1 && ConvWeightLayout1D<L>) || (SpatialDim == 2 && ConvWeightLayout2D<L>) ||
|
||||
(SpatialDim == 3 && ConvWeightLayout3D<L>);
|
||||
|
||||
// Constraint for 3D conv signature.
|
||||
template <auto Sig>
|
||||
concept Is3D = requires {
|
||||
requires Sig.spatial_dim == 3;
|
||||
requires ConvInputLayout3D<Sig.input.config.layout>;
|
||||
requires ConvOutputLayout3D<Sig.output.config.layout>;
|
||||
requires ConvWeightLayout3D<Sig.weight.config.layout>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Base algorithm concepts
|
||||
template <typename T, size_t ThreadClusterRank = 3>
|
||||
concept TileTransferParameters =
|
||||
SpecifiesBlockTransfer<T, ThreadClusterRank> && SpecifiesLdsTransfer<T> &&
|
||||
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T>;
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesTileTransferParameters3D = TileTransferParameters<T, 3>;
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesTileTransferParameters4D = TileTransferParameters<T, 4>;
|
||||
|
||||
template <typename T>
|
||||
concept FwdXdlAlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseFwdXdlGemm<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
|
||||
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdXdlAlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters4D<T> &&
|
||||
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdXdlV3AlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
|
||||
SpecifiesBlockGemm<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdWmmaAlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseWmmaGemm<T> && SpecifiesBwdWeightConvSpecialization<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdWmmaV3AlgorithmBase =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseWmmaGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
|
||||
SpecifiesBlockGemm<T>;
|
||||
|
||||
// Reference algorithm concept
|
||||
template <typename T>
|
||||
concept ReferenceAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesReferenceAlgorithm<T>;
|
||||
|
||||
// Tile-based algorithm concept
|
||||
template <typename T>
|
||||
concept TileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> &&
|
||||
SpecifiesTileTransfer<T> && SpecifiesTileConvSpecialization<T> &&
|
||||
SpecifiesTileBlockGemm<T> && SpecifiesTileOptimizations<T>;
|
||||
|
||||
// FWD XDL algorithm concepts
|
||||
template <typename T>
|
||||
concept FwdXdlAlgorithm = FwdXdlAlgorithmBase<T> && SpecifiesGenericInstance<T>;
|
||||
|
||||
template <typename T>
|
||||
concept LargeTensorAlgorithm = FwdXdlAlgorithmBase<T> && SpecifiesLargeTensorSupport<T>;
|
||||
|
||||
template <typename T>
|
||||
concept FwdXdlV3Algorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseFwdXdlGemm<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
|
||||
|
||||
// FWD WMMA algorithm concepts
|
||||
template <typename T>
|
||||
concept FwdWmmaAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
|
||||
SpecifiesGridwiseWmmaGemm<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T> &&
|
||||
SpecifiesGridwiseGemmPipeline<T>;
|
||||
|
||||
// FWD DL algorithms
|
||||
template <typename T>
|
||||
concept FwdDlAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlFwdBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
|
||||
// BWD weight XDL algorithm concepts
|
||||
template <typename T>
|
||||
concept BwdXdlAlgorithm =
|
||||
BwdXdlAlgorithmBase<T> && SpecifiesTransposeTransfer<T> && SpecifiesGenericInstance<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdMultiDXdlAlgorithm = BwdXdlAlgorithmBase<T> && SpecifiesMultipleDSupport<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdXdlV3Algorithm = BwdXdlV3AlgorithmBase<T> && SpecifiesGenericInstance<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdTwoStageXdlAlgorithm = BwdXdlV3AlgorithmBase<T> && SpecifiesTransposeTransfer<T> &&
|
||||
SpecifiesGemmBatchOptions<T> && SpecifiesTwoStageSupport<T>;
|
||||
|
||||
// BWD weight WMMA algorithm concepts
|
||||
template <typename T>
|
||||
concept BwdWmmaAlgorithm =
|
||||
BwdWmmaAlgorithmBase<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T> &&
|
||||
SpecifiesGridwiseGemmPipeline<T> && SpecifiesGenericInstance<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase<T> && SpecifiesMultipleDSupport<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdWmmaV3Algorithm =
|
||||
BwdWmmaV3AlgorithmBase<T> && SpecifiesTransposeTransfer<T> && SpecifiesGenericInstance<T>;
|
||||
|
||||
template <typename T>
|
||||
concept BwdTwoStageWmmaV3Algorithm = BwdWmmaV3AlgorithmBase<T> && SpecifiesTransposeTransfer<T> &&
|
||||
SpecifiesGemmBatchOptions<T> && SpecifiesTwoStageSupport<T>;
|
||||
|
||||
// BWD weigth DL algorithms
|
||||
template <typename T>
|
||||
concept BwdDlAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
|
||||
SpecifiesBwdWeightConvSpecialization<T> && SpecifiesDlThreadConfig<T> &&
|
||||
SpecifiesDlThreadCluster<T> && SpecifiesDlBwdBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,131 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeight_Dl instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightDlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
|
||||
// DL-specific parameters from algorithm descriptor
|
||||
static constexpr auto DL_THREAD_CFG = ALGORITHM.thread_config;
|
||||
static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block;
|
||||
static constexpr ck::index_t K1 = DL_THREAD_CFG.k1;
|
||||
static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread;
|
||||
static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread;
|
||||
static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread;
|
||||
|
||||
// Thread cluster from descriptor
|
||||
static constexpr auto DL_CLUSTER = ALGORITHM.thread_cluster;
|
||||
using M1N1ThreadClusterM1Xs = to_sequence_v<DL_CLUSTER.m1_xs>;
|
||||
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
|
||||
|
||||
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a;
|
||||
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_cluster_lengths>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_cluster_arrange_order>;
|
||||
using ABlockTransferSrcAccessOrder = to_sequence_v<DL_A_TRANSFER.src_access_order>;
|
||||
using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_lengths>;
|
||||
using ABlockTransferSrcVectorTensorContiguousDimOrder =
|
||||
to_sequence_v<DL_A_TRANSFER.src_vector_tensor_contiguous_dim_order>;
|
||||
using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b;
|
||||
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_cluster_lengths>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_cluster_arrange_order>;
|
||||
using BBlockTransferSrcAccessOrder = to_sequence_v<DL_B_TRANSFER.src_access_order>;
|
||||
using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_lengths>;
|
||||
using BBlockTransferSrcVectorTensorContiguousDimOrder =
|
||||
to_sequence_v<DL_B_TRANSFER.src_vector_tensor_contiguous_dim_order>;
|
||||
using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// C Thread Transfer from descriptor
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c;
|
||||
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
|
||||
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
|
||||
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
|
||||
DL_C_TRANSFER.dst_scalar_per_vector;
|
||||
|
||||
// The DL forward convolution kernel class instance
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
K0PerBlock,
|
||||
K1,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM1Xs,
|
||||
M1N1ThreadClusterN1Xs,
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,110 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE> && Is3D<SIGNATURE>
|
||||
struct ConvBwdWeightMultiDWmmaV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid A thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid B thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid A source access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::DsDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,103 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightMultiDXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::DsDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,111 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffle_V3 instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightTwoStageWmmaV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid A thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid B thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid A source access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
ALGORITHM.num_conv_groups_to_merge,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType,
|
||||
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
|
||||
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,111 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightTwoStageXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid A thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid B thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid A source access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
ALGORITHM.num_conv_groups_to_merge,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType,
|
||||
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
|
||||
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,109 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeight_Wmma_CShuffle instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE> && Is3D<SIGNATURE>
|
||||
struct ConvBwdWeightWmmaFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION =
|
||||
internal::SetGridwiseGemmPipelineVersion<ALGORITHM>();
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid A thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid B thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid A source access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
LOOP_SCHEDULER,
|
||||
GRIDWISE_GEMM_PIPELINE_VERSION>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,109 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightWmmaV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid A thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid B thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid A source access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType,
|
||||
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
|
||||
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,103 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeight_Xdl_CShuffle instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType,
|
||||
ALGORITHM.max_transpose_transfer_src_scalar_per_vector,
|
||||
ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -0,0 +1,108 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
// Factory for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 instance
|
||||
// of a grouped bwd weight convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct ConvBwdWeightXdlV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = internal::SetBlockGemm<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>, "Invalid A block transfer config");
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>, "Invalid B block transfer config");
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>, "Invalid C block transfer config");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid A thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
"Invalid B thread cluster access order");
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid A source access order");
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
BWD_CONV_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::InComputeType>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
@@ -57,6 +57,9 @@
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
// Compile time diagnostics
|
||||
#include "ck_tile/builder/factory/conv_algorithms.hpp"
|
||||
|
||||
// Include all factory implementations
|
||||
#include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp"
|
||||
@@ -65,6 +68,15 @@
|
||||
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
|
||||
#include "ck_tile/builder/factory/reference_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_tile_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp"
|
||||
#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
|
||||
@@ -87,56 +99,6 @@ namespace ck_tile::builder::factory {
|
||||
//
|
||||
// TODO: Make this dispatch logic much more robust and clear for users.
|
||||
|
||||
// Reference algorithm (simplest implementation for validation)
|
||||
template <typename T>
|
||||
concept IsReferenceAlgorithm = ConvAlgorithmDescriptor<T> && requires {
|
||||
{ T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
|
||||
requires T::specialization == ConvAlgorithmSpecialization::REFERENCE;
|
||||
};
|
||||
|
||||
// CK Tile kernel
|
||||
template <typename T>
|
||||
concept IsTileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> &&
|
||||
SpecifiesTileTransfer<T> && SpecifiesTileConvSpecialization<T> &&
|
||||
SpecifiesTileBlockGemm<T> && SpecifiesTileOptimizations<T>;
|
||||
|
||||
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
|
||||
template <typename T>
|
||||
concept IsXdlV3Algorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesBlockGemm<T>;
|
||||
|
||||
// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply)
|
||||
template <typename T>
|
||||
concept IsXdlAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> &&
|
||||
SpecifiesNumGroupsToMerge<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
// WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions)
|
||||
template <typename T>
|
||||
concept IsWmmaAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
|
||||
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> && SpecifiesThreadClusterAccessOrder<T> &&
|
||||
SpecifiesSourceAccessOrder<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
|
||||
|
||||
// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts
|
||||
template <typename T>
|
||||
concept IsDlAlgorithm =
|
||||
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesFwdConvSpecialization<T> &&
|
||||
SpecifiesGemmSpecialization<T> && SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
|
||||
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
|
||||
|
||||
// XDL-based kernel with large tensor support
|
||||
template <typename T>
|
||||
concept IsLargeTensorAlgorithm =
|
||||
IsXdlAlgorithm<decltype(T::base_algorithm)> && SpecifiesLargeTensorSupport<T>;
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
@@ -145,35 +107,35 @@ constexpr auto make_conv_instance()
|
||||
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
|
||||
|
||||
// Reference algorithm supports all directions
|
||||
if constexpr(IsReferenceAlgorithm<AlgoType>)
|
||||
if constexpr(ReferenceAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ReferenceFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
// CK Tile supports common factory for each direction
|
||||
else if constexpr(IsTileAlgorithm<AlgoType>)
|
||||
else if constexpr(TileAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
// Forward direction (supports most algorithm variants)
|
||||
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
if constexpr(IsXdlV3Algorithm<AlgoType>)
|
||||
if constexpr(FwdXdlV3Algorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsXdlAlgorithm<AlgoType>)
|
||||
else if constexpr(FwdXdlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsWmmaAlgorithm<AlgoType>)
|
||||
else if constexpr(FwdWmmaAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsDlAlgorithm<AlgoType>)
|
||||
else if constexpr(FwdDlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(IsLargeTensorAlgorithm<AlgoType>)
|
||||
else if constexpr(LargeTensorAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvFwdLargeTensorFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
@@ -197,10 +159,55 @@ constexpr auto make_conv_instance()
|
||||
// Backward weight direction (will expand with more algorithms in the future)
|
||||
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
{
|
||||
static_assert(false,
|
||||
"Backward weight convolution: Only reference and tile algorithms "
|
||||
"supported currently. "
|
||||
"Optimized kernels (XDL, WMMA, etc.) not yet implemented.");
|
||||
if constexpr(BwdXdlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdWeightXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdXdlV3Algorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdWeightXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdTwoStageXdlAlgorithm<AlgoType>)
|
||||
{
|
||||
return
|
||||
typename ConvBwdWeightTwoStageXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdDlAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdWeightDlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdMultiDXdlAlgorithm<AlgoType>)
|
||||
{
|
||||
return
|
||||
typename ConvBwdWeightMultiDXdlFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdWmmaV3Algorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdWeightWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdTwoStageWmmaV3Algorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdWeightTwoStageWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::
|
||||
Instance{};
|
||||
}
|
||||
else if constexpr(BwdWmmaAlgorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdWeightWmmaFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
|
||||
}
|
||||
else if constexpr(BwdMultiDWmmaV3Algorithm<AlgoType>)
|
||||
{
|
||||
return typename ConvBwdWeightMultiDWmmaV3Factory<SIGNATURE, ALGORITHM, VERSION>::
|
||||
Instance{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(
|
||||
false,
|
||||
"No suitable backward weight convolution kernel factory found for the provided "
|
||||
"ALGORITHM. The ALGORITHM must satisfy requirements for one of: Reference, Tile, "
|
||||
"XDL, XDL V3, Two-Stage XDL, DL, Multi-D XDL, WMMA V3, Two-Stage "
|
||||
"WMMA V3, WMMA, or Multi-D WMMA V3 variant.");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -24,10 +24,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdDlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
|
||||
@@ -48,7 +48,7 @@ struct ConvFwdDlFactory
|
||||
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
|
||||
|
||||
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a.block_transfer;
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a;
|
||||
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
|
||||
@@ -64,7 +64,7 @@ struct ConvFwdDlFactory
|
||||
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b.block_transfer;
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b;
|
||||
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
|
||||
@@ -80,7 +80,7 @@ struct ConvFwdDlFactory
|
||||
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// C Thread Transfer from descriptor
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c.epilogue;
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c;
|
||||
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
|
||||
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
|
||||
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
|
||||
@@ -89,18 +89,18 @@ struct ConvFwdDlFactory
|
||||
// The DL forward convolution kernel class instance
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<
|
||||
SPATIAL_DIM,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::DsDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
typename Layouts::OutLayout,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
FWD_CONV_SPECIALIZATION,
|
||||
GEMM_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
|
||||
@@ -26,68 +26,65 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdLargeTensorFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm;
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
internal::SetFwdConvSpecialization<BASE_ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<BASE_ALGORITHM>();
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
|
||||
static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
|
||||
.gemm_spec = GEMM_SPECIALIZATION};
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<BASE_ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<BASE_ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<BASE_ALGORITHM.transfer.a>();
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<BASE_ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER =
|
||||
internal::SetCBlockTransfer<SIGNATURE, BASE_ALGORITHM>();
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance with large tensor support.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::DsDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
BASE_ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_xdl,
|
||||
GRIDWISE_GEMM.n_per_xdl,
|
||||
GRIDWISE_GEMM.m_xdl_per_wave,
|
||||
GRIDWISE_GEMM.n_xdl_per_wave,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
@@ -106,8 +103,8 @@ struct ConvFwdLargeTensorFactory
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
typename Types::AComputeType,
|
||||
typename Types::BComputeType,
|
||||
typename Types::InComputeType,
|
||||
typename Types::WeiComputeType,
|
||||
LOOP_SCHEDULER>;
|
||||
};
|
||||
|
||||
|
||||
@@ -26,10 +26,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdXdlV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load ==
|
||||
ALGORITHM.transfer.b.lds_transfer.is_direct_load,
|
||||
@@ -43,6 +43,7 @@ struct ConvFwdXdlV3Factory
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
@@ -55,27 +56,27 @@ struct ConvFwdXdlV3Factory
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::DsDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
BLOCK.block_size,
|
||||
@@ -84,10 +85,10 @@ struct ConvFwdXdlV3Factory
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_xdl,
|
||||
GRIDWISE_GEMM.n_per_xdl,
|
||||
GRIDWISE_GEMM.m_xdl_per_wave,
|
||||
GRIDWISE_GEMM.n_xdl_per_wave,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
@@ -108,8 +109,8 @@ struct ConvFwdXdlV3Factory
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
BLOCK_GEMM.scheduler,
|
||||
BLOCK_GEMM.pipeline_version,
|
||||
typename Types::AComputeType,
|
||||
typename Types::BComputeType,
|
||||
typename Types::InComputeType,
|
||||
typename Types::WeiComputeType,
|
||||
IS_DIRECT_LOAD>;
|
||||
};
|
||||
|
||||
|
||||
@@ -26,10 +26,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdWmmaFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
|
||||
@@ -52,27 +52,27 @@ struct ConvFwdWmmaFactory
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::DsDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
|
||||
@@ -26,10 +26,10 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
|
||||
@@ -39,6 +39,7 @@ struct ConvFwdXdlFactory
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
@@ -50,27 +51,27 @@ struct ConvFwdXdlFactory
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::InDataType,
|
||||
typename Types::WeiDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
typename Types::OutComputeType,
|
||||
typename Types::DsDataType,
|
||||
typename Types::OutDataType,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
@@ -80,10 +81,10 @@ struct ConvFwdXdlFactory
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_xdl,
|
||||
GRIDWISE_GEMM.n_per_xdl,
|
||||
GRIDWISE_GEMM.m_xdl_per_wave,
|
||||
GRIDWISE_GEMM.n_xdl_per_wave,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
@@ -102,10 +103,10 @@ struct ConvFwdXdlFactory
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
typename Types::AComputeType,
|
||||
typename Types::BComputeType,
|
||||
typename Types::InComputeType,
|
||||
typename Types::WeiComputeType,
|
||||
LOOP_SCHEDULER,
|
||||
ALGORITHM.num_groups_to_merge>;
|
||||
ALGORITHM.num_conv_groups_to_merge>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
|
||||
@@ -10,27 +10,28 @@
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Block transfer parameters for A or B tensor.
|
||||
template <size_t ThreadClusterRank = 3>
|
||||
struct BlockTransfer
|
||||
{
|
||||
ck::Array<size_t, 3> thread_cluster_dims = {0, 0, 0}; // k0, m, k1
|
||||
ck::Array<size_t, 3> thread_cluster_order = {0, 0, 0};
|
||||
ck::Array<size_t, 3> src_access_order = {0, 0, 0};
|
||||
size_t src_vector_dim = 0;
|
||||
size_t src_scalar_per_vector = 0;
|
||||
size_t lds_dst_scalar_per_vector = 0;
|
||||
bool is_direct_load = false;
|
||||
bool lds_padding = false;
|
||||
ck::Array<size_t, ThreadClusterRank> thread_cluster_dims{};
|
||||
ck::Array<size_t, ThreadClusterRank> thread_cluster_order{};
|
||||
ck::Array<size_t, ThreadClusterRank> src_access_order{};
|
||||
size_t src_vector_dim = 0;
|
||||
size_t src_scalar_per_vector = 0;
|
||||
size_t lds_dst_scalar_per_vector = 0;
|
||||
bool is_direct_load = false;
|
||||
bool lds_padding = false;
|
||||
};
|
||||
|
||||
template <auto TRANSFER>
|
||||
constexpr BlockTransfer SetFwdConvBlockTransfer()
|
||||
constexpr BlockTransfer<> SetFwdConvBlockTransfer()
|
||||
{
|
||||
auto& block_xfer = TRANSFER.block_transfer;
|
||||
auto& block_order = TRANSFER.block_transfer_access_order;
|
||||
auto& src_order = TRANSFER.src_access_order;
|
||||
auto& lds_cfg = TRANSFER.lds_transfer;
|
||||
|
||||
return BlockTransfer{
|
||||
return BlockTransfer<>{
|
||||
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]},
|
||||
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
|
||||
@@ -42,6 +43,59 @@ constexpr BlockTransfer SetFwdConvBlockTransfer()
|
||||
};
|
||||
}
|
||||
|
||||
template <auto TRANSFER>
|
||||
constexpr auto SetBwdConvBlockTransfer()
|
||||
{
|
||||
auto& block_xfer = TRANSFER.block_transfer;
|
||||
auto& block_order = TRANSFER.block_transfer_access_order;
|
||||
auto& src_order = TRANSFER.src_access_order;
|
||||
auto& lds_cfg = TRANSFER.lds_transfer;
|
||||
|
||||
constexpr auto array_length = block_order.order.size();
|
||||
static_assert(block_order.order.size() == src_order.order.size(),
|
||||
"Mismatched size between block order and src order");
|
||||
|
||||
if constexpr(array_length == 3)
|
||||
{
|
||||
return BlockTransfer<3>{
|
||||
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0],
|
||||
block_order.order[1],
|
||||
block_order.order[2]},
|
||||
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
|
||||
.src_vector_dim = lds_cfg.src_vector_dim,
|
||||
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
|
||||
.lds_padding = lds_cfg.lds_padding,
|
||||
};
|
||||
}
|
||||
else if constexpr(array_length == 4)
|
||||
{
|
||||
return BlockTransfer<4>{
|
||||
.thread_cluster_dims = {block_xfer.k_batch_size,
|
||||
block_xfer.k0,
|
||||
block_xfer.m_n,
|
||||
block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0],
|
||||
block_order.order[1],
|
||||
block_order.order[2],
|
||||
block_order.order[3]},
|
||||
.src_access_order = {src_order.order[0],
|
||||
src_order.order[1],
|
||||
src_order.order[2],
|
||||
src_order.order[3]},
|
||||
.src_vector_dim = lds_cfg.src_vector_dim,
|
||||
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
|
||||
.lds_padding = lds_cfg.lds_padding,
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Internal error: Unsupported array length");
|
||||
}
|
||||
}
|
||||
|
||||
// Block transfer parameters for C tensor.
|
||||
struct CBlockTransfer
|
||||
{
|
||||
|
||||
@@ -62,14 +62,15 @@ consteval auto GetElementwiseOp()
|
||||
}
|
||||
|
||||
template <auto Sig>
|
||||
struct ElementwiseOps
|
||||
struct ConvElementwiseOps
|
||||
{
|
||||
static constexpr auto input_op = GetElementwiseOp<Sig.input>();
|
||||
static constexpr auto weight_op = GetElementwiseOp<Sig.weight>();
|
||||
static constexpr auto output_op = GetElementwiseOp<Sig.output>();
|
||||
using AElementwiseOp = typename decltype(input_op)::Op;
|
||||
using BElementwiseOp = typename decltype(weight_op)::Op;
|
||||
using CDEElementwiseOp = typename decltype(output_op)::Op;
|
||||
|
||||
using InElementwiseOp = typename decltype(input_op)::Op;
|
||||
using WeiElementwiseOp = typename decltype(weight_op)::Op;
|
||||
using OutElementwiseOp = typename decltype(output_op)::Op;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -190,7 +190,7 @@ consteval auto GetAuxiliaryTensorLayoutTuple(std::index_sequence<Indices...>)
|
||||
decltype(TensorLayoutToCK<AuxiliaryTensorConfigsArray[Indices].layout>())...>{};
|
||||
}
|
||||
|
||||
template <auto AuxiliaryTensorConfigsValue, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
template <auto AuxiliaryTensorConfigsValue, size_t SPATIAL_DIM>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM>)
|
||||
struct AuxiliaryTensorLayouts
|
||||
{
|
||||
@@ -200,34 +200,32 @@ struct AuxiliaryTensorLayouts
|
||||
};
|
||||
|
||||
// TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias).
|
||||
template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
|
||||
consteval auto GetAuxiliaryTensorLayouts()
|
||||
{
|
||||
return AuxiliaryTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
|
||||
SPATIAL_DIM,
|
||||
DIR>{};
|
||||
SPATIAL_DIM>{};
|
||||
}
|
||||
|
||||
template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
|
||||
consteval auto GetAuxiliaryTensorLayouts()
|
||||
{
|
||||
return EmptyAuxiliaryTensorLayout{};
|
||||
}
|
||||
|
||||
template <auto Signature, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM> &&
|
||||
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
|
||||
struct ConvTensorLayouts
|
||||
{
|
||||
static_assert(DIR == ConvDirection::FORWARD, "Only Forward convolution is supported.");
|
||||
using ALayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
|
||||
using BLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
|
||||
using ELayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
|
||||
using DsLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM, DIR>())::type;
|
||||
using InLayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
|
||||
using WeiLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
|
||||
using OutLayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
|
||||
using DsLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM>())::type;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -33,7 +33,7 @@ struct DataTypeToCK<DataType::FP32>
|
||||
using type = float;
|
||||
};
|
||||
template <>
|
||||
struct DataTypeToCK<DataType::INT32>
|
||||
struct DataTypeToCK<DataType::I32>
|
||||
{
|
||||
using type = int32_t;
|
||||
};
|
||||
@@ -156,7 +156,7 @@ consteval auto GetAuxiliaryTensorDataTypes()
|
||||
}
|
||||
|
||||
template <auto Signature>
|
||||
struct FwdConvTensorDataTypes
|
||||
struct ConvTensorDataTypes
|
||||
{
|
||||
static constexpr auto input_types =
|
||||
GetTensorDataAndComputeTypes<Signature.input.config, Signature.data_type>();
|
||||
@@ -165,20 +165,17 @@ struct FwdConvTensorDataTypes
|
||||
static constexpr auto output_types =
|
||||
GetTensorDataAndComputeTypes<Signature.output.config, Signature.data_type>();
|
||||
|
||||
using ADataType = typename decltype(input_types.first)::type;
|
||||
using AComputeType = typename decltype(input_types.second)::type;
|
||||
using BDataType = typename decltype(weight_types.first)::type;
|
||||
using BComputeType = typename decltype(weight_types.second)::type;
|
||||
using InDataType = typename decltype(input_types.first)::type;
|
||||
using InComputeType = typename decltype(input_types.second)::type;
|
||||
using WeiDataType = typename decltype(weight_types.first)::type;
|
||||
using WeiComputeType = typename decltype(weight_types.second)::type;
|
||||
using OutDataType = typename decltype(output_types.first)::type;
|
||||
using OutComputeType = typename decltype(output_types.second)::type;
|
||||
using AccDataType =
|
||||
typename decltype(GetTensorAccumulationType<Signature.accumulation_data_type,
|
||||
Signature.data_type>())::type;
|
||||
using EDataType = typename decltype(output_types.first)::type;
|
||||
|
||||
// This is the "compute" type for output.
|
||||
using CShuffleDataType = typename decltype(output_types.second)::type;
|
||||
|
||||
// Data types for the auxiliary tensors (e.g., bias).
|
||||
using DsDataTypes = typename decltype(GetAuxiliaryTensorDataTypes<Signature>())::type;
|
||||
using DsDataType = typename decltype(GetAuxiliaryTensorDataTypes<Signature>())::type;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
@@ -37,7 +38,7 @@ struct BlockGemmSpec
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval BlockGemmSpec SetBlockGemm()
|
||||
{
|
||||
constexpr auto& BG = ALGORITHM.block_gemm;
|
||||
constexpr auto& BG = ALGORITHM.block_gemm_pipeline;
|
||||
|
||||
ck::BlockGemmPipelineScheduler scheduler;
|
||||
ck::BlockGemmPipelineVersion version;
|
||||
@@ -82,7 +83,7 @@ consteval ck::LoopScheduler SetLoopScheduler()
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
|
||||
constexpr auto pipeline_version = ALGORITHM.pipeline_version;
|
||||
using ck_pipeline = ck::PipelineVersion;
|
||||
switch(pipeline_version)
|
||||
{
|
||||
@@ -149,12 +150,30 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC
|
||||
using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default;
|
||||
case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
|
||||
case ConvFwdSpecialization::ODD_C: return ck_conv_spec::OddC;
|
||||
default: throw "Unknown ConvFwdSpecialization";
|
||||
case ConvSpecialization::DEFAULT: return ck_conv_spec::Default;
|
||||
case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3;
|
||||
case ConvSpecialization::ODD_C: return ck_conv_spec::OddC;
|
||||
default: throw "Unsupported ConvSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization
|
||||
SetBwdWeightConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.bwd_weight_specialization;
|
||||
using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
case ConvSpecialization::DEFAULT: return ck_conv_spec::Default;
|
||||
case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvSpecialization::ODD_C: return ck_conv_spec::OddC;
|
||||
case ConvSpecialization::FILTER_3x3:
|
||||
throw "FILTER_3x3 is not supported for backward weight convolution.";
|
||||
default: throw "Unsupported ConvSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -26,11 +26,11 @@ struct ReferenceFactory
|
||||
static constexpr auto kValidation = (internal::ValidateReferenceSignature<SIGNATURE>(), 0);
|
||||
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
|
||||
using InDataType = typename Types::ADataType;
|
||||
using WeiDataType = typename Types::BDataType;
|
||||
using OutDataType = typename Types::EDataType;
|
||||
using InDataType = typename Types::InDataType;
|
||||
using WeiDataType = typename Types::WeiDataType;
|
||||
using OutDataType = typename Types::OutDataType;
|
||||
|
||||
struct Instance
|
||||
{
|
||||
|
||||
@@ -63,10 +63,7 @@ struct GemmAlgorithmInfo
|
||||
OutputTileTransferInfo c_tile_transfer;
|
||||
builder::PipelineVersion pipeline_version;
|
||||
builder::PipelineScheduler pipeline_scheduler;
|
||||
std::variant<builder::ConvFwdSpecialization,
|
||||
builder::ConvBwdDataSpecialization,
|
||||
builder::ConvBwdWeightSpecialization>
|
||||
conv_specialization;
|
||||
builder::ConvSpecialization conv_specialization;
|
||||
builder::GemmPadding padding;
|
||||
};
|
||||
|
||||
|
||||
@@ -197,18 +197,16 @@ constexpr builder::ConvDirection conv_direction()
|
||||
|
||||
/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type.
|
||||
/// @tparam Instance The device kernel instance type.
|
||||
/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or
|
||||
/// `builder::ConvBwdWeightSpecialization` enum value.
|
||||
/// @return A `builder::ConvSpecialization` enum value.
|
||||
template <typename Instance>
|
||||
constexpr auto conv_spec()
|
||||
{
|
||||
using InstTraits = InstanceTraits<Instance>;
|
||||
using enum builder::ConvSpecialization;
|
||||
|
||||
if constexpr(requires { InstTraits::kConvForwardSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
using enum builder::ConvFwdSpecialization;
|
||||
|
||||
switch(InstTraits::kConvForwardSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
@@ -221,8 +219,6 @@ constexpr auto conv_spec()
|
||||
else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
using enum builder::ConvBwdDataSpecialization;
|
||||
|
||||
switch(InstTraits::kConvBwdDataSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
@@ -232,8 +228,6 @@ constexpr auto conv_spec()
|
||||
else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; })
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
using enum builder::ConvBwdWeightSpecialization;
|
||||
|
||||
switch(InstTraits::kConvBwdWeightSpecialization)
|
||||
{
|
||||
case Default: return DEFAULT;
|
||||
|
||||
@@ -35,10 +35,10 @@ struct ReferenceCommonTraits
|
||||
typename builder::factory::internal::LayoutToCK<SIGNATURE.output.config.layout>::type;
|
||||
|
||||
// Data types - extract from factory's type helper
|
||||
using Types = builder::factory::internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using ADataType = typename Types::ADataType;
|
||||
using BDataType = typename Types::BDataType;
|
||||
using EDataType = typename Types::EDataType;
|
||||
using Types = builder::factory::internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using ADataType = typename Types::InDataType;
|
||||
using BDataType = typename Types::WeiDataType;
|
||||
using EDataType = typename Types::OutDataType;
|
||||
using AccDataType = float; // Reference uses float accumulation
|
||||
|
||||
// Elementwise operations - reference only supports PassThrough
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
#include "ck_tile/builder/testing/testing_reflect.hpp"
|
||||
#include "ck_tile/builder/testing/filter_extent.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_buffer.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_initialization.hpp"
|
||||
@@ -71,11 +72,10 @@ struct Args<SIGNATURE>
|
||||
using OutputDescriptor = TensorDescriptor<OUTPUT_TYPE, OUTPUT_RANK>;
|
||||
|
||||
// TODO: We shouldn't need to call into an internal namespace here.
|
||||
using Ops = factory::internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = factory::internal::ConvElementwiseOps<SIGNATURE>;
|
||||
|
||||
// TODO: We shouldn't need to call into an internal namespace here.
|
||||
using Layouts =
|
||||
factory::internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Layouts = factory::internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
|
||||
ConvTensorLengths<SPATIAL_DIM> lengths;
|
||||
|
||||
@@ -89,9 +89,9 @@ struct Args<SIGNATURE>
|
||||
FilterExtent<SPATIAL_DIM> input_left_pad;
|
||||
FilterExtent<SPATIAL_DIM> input_right_pad;
|
||||
|
||||
Ops::AElementwiseOp a_elementwise_op;
|
||||
Ops::BElementwiseOp b_elementwise_op;
|
||||
Ops::CDEElementwiseOp cde_elementwise_op;
|
||||
Ops::InElementwiseOp a_elementwise_op;
|
||||
Ops::WeiElementwiseOp b_elementwise_op;
|
||||
Ops::OutElementwiseOp cde_elementwise_op;
|
||||
|
||||
/// This function returns the `TensorDescriptor` corresponding to
|
||||
/// the input-tensor of the convolution problem. This can then
|
||||
@@ -106,7 +106,7 @@ struct Args<SIGNATURE>
|
||||
// function.
|
||||
const auto param = to_ck_conv_param();
|
||||
const auto desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
|
||||
typename Layouts::ALayout>(param);
|
||||
typename Layouts::InLayout>(param);
|
||||
using Extent = typename InputDescriptor::Extent;
|
||||
return InputDescriptor(Extent::from_vector(desc.GetLengths()),
|
||||
Extent::from_vector(desc.GetStrides()));
|
||||
@@ -120,7 +120,7 @@ struct Args<SIGNATURE>
|
||||
// See note in implementation of `make_input_descriptor`.
|
||||
const auto param = to_ck_conv_param();
|
||||
const auto desc = ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<
|
||||
typename Layouts::BLayout>(param);
|
||||
typename Layouts::WeiLayout>(param);
|
||||
using Extent = typename WeightDescriptor::Extent;
|
||||
return WeightDescriptor(Extent::from_vector(desc.GetLengths()),
|
||||
Extent::from_vector(desc.GetStrides()));
|
||||
@@ -134,7 +134,7 @@ struct Args<SIGNATURE>
|
||||
// See note in implementation of `make_input_descriptor`.
|
||||
const auto param = to_ck_conv_param();
|
||||
const auto desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<
|
||||
typename Layouts::ELayout>(param);
|
||||
typename Layouts::OutLayout>(param);
|
||||
using Extent = typename OutputDescriptor::Extent;
|
||||
return OutputDescriptor(Extent::from_vector(desc.GetLengths()),
|
||||
Extent::from_vector(desc.GetStrides()));
|
||||
@@ -182,6 +182,12 @@ struct Inputs<SIGNATURE>
|
||||
{
|
||||
void* input;
|
||||
void* weight;
|
||||
|
||||
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
inspect("input", args.make_input_descriptor(), &Inputs<SIGNATURE>::input);
|
||||
inspect("weight", args.make_weight_descriptor(), &Inputs<SIGNATURE>::weight);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `Outputs` specialization for forward convolution.
|
||||
@@ -194,68 +200,13 @@ template <auto SIGNATURE>
|
||||
struct Outputs<SIGNATURE>
|
||||
{
|
||||
void* output;
|
||||
};
|
||||
|
||||
/// @brief `UniqueInputs` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see UniqueInputs
|
||||
/// @see ValidUniqueInputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
struct UniqueInputs<SIGNATURE>
|
||||
{
|
||||
DeviceBuffer input_buf;
|
||||
DeviceBuffer weight_buf;
|
||||
|
||||
/// @see ValidUniqueInputs
|
||||
Inputs<SIGNATURE> get()
|
||||
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
return {
|
||||
.input = input_buf.get(),
|
||||
.weight = weight_buf.get(),
|
||||
};
|
||||
inspect("output", args.make_output_descriptor(), &Outputs<SIGNATURE>::output);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `UniqueOutputs` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see UniqueOutputs
|
||||
/// @see ValidUniqueOutputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
struct UniqueOutputs<SIGNATURE>
|
||||
{
|
||||
DeviceBuffer output_buf;
|
||||
|
||||
/// @see ValidUniqueOutputs
|
||||
Outputs<SIGNATURE> get()
|
||||
{
|
||||
return {
|
||||
.output = output_buf.get(),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `alloc_inputs()` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see alloc_inputs()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE> &&
|
||||
ValidUniqueInputs<SIGNATURE>
|
||||
UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args)
|
||||
{
|
||||
return {
|
||||
.input_buf = alloc_tensor_buffer(args.make_input_descriptor()),
|
||||
.weight_buf = alloc_tensor_buffer(args.make_weight_descriptor()),
|
||||
};
|
||||
}
|
||||
|
||||
/// @brief `init_inputs()` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
@@ -269,34 +220,4 @@ void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
|
||||
init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f);
|
||||
}
|
||||
|
||||
/// @brief `alloc_outputs()` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see alloc_outputs()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE> &&
|
||||
ValidUniqueOutputs<SIGNATURE>
|
||||
UniqueOutputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args)
|
||||
{
|
||||
return {
|
||||
.output_buf = alloc_tensor_buffer(args.make_output_descriptor()),
|
||||
};
|
||||
}
|
||||
|
||||
/// @brief `validate()` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see validate()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
ValidationReport
|
||||
validate(const Args<SIGNATURE>& args, Outputs<SIGNATURE> actual, Outputs<SIGNATURE> expected)
|
||||
{
|
||||
ValidationReport report;
|
||||
report.check("output", args.make_output_descriptor(), actual.output, expected.output);
|
||||
return report;
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -27,7 +27,7 @@ template <typename Conv,
|
||||
auto SIGNATURE,
|
||||
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
|
||||
// TODO: We shouldn't need to call into an internal namespace here.
|
||||
typename Ops = factory::internal::ElementwiseOps<SIGNATURE>>
|
||||
typename Ops = factory::internal::ConvElementwiseOps<SIGNATURE>>
|
||||
concept CkConvInstance = requires(Conv& conv,
|
||||
// TODO: This should be changed depending on IsMultiA etc.
|
||||
// Currently that is not yet supported elsewhere anyway.
|
||||
@@ -37,9 +37,9 @@ concept CkConvInstance = requires(Conv& conv,
|
||||
std::array<index_t, SPATIAL_DIM + 3> lengths,
|
||||
std::array<index_t, SPATIAL_DIM + 3> strides,
|
||||
std::array<index_t, SPATIAL_DIM> filter,
|
||||
Ops::AElementwiseOp elementwise_a,
|
||||
Ops::BElementwiseOp elementwise_b,
|
||||
Ops::CDEElementwiseOp elementwise_cde) {
|
||||
Ops::InElementwiseOp elementwise_a,
|
||||
Ops::WeiElementwiseOp elementwise_b,
|
||||
Ops::OutElementwiseOp elementwise_cde) {
|
||||
{
|
||||
conv.MakeArgument(p_a,
|
||||
p_b,
|
||||
|
||||
634
experimental/builder/include/ck_tile/builder/testing/debug.hpp
Normal file
634
experimental/builder/include/ck_tile/builder/testing/debug.hpp
Normal file
@@ -0,0 +1,634 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
|
||||
#include "ck_tile/builder/testing/error.hpp"
|
||||
#include "ck_tile/builder/testing/type_traits.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include <iostream>
|
||||
#include <locale>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <syncstream>
|
||||
#include <concepts>
|
||||
#include <limits>
|
||||
|
||||
/// This file contains a few debugging utilities, mainly focused around
|
||||
/// tensor data. The idea is that the functionality in this file is not
|
||||
/// necessarily used in any testing directly, but is available for the
|
||||
/// programmer to help with debugging problems. These utilities themselves
|
||||
/// should be tested just the same, though, so that they don't undergo
|
||||
/// bitrot while they are not actively being used.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// @brief Custom number punctuation for CK-Builder debugging.
|
||||
///
|
||||
/// During debugging, the locale is usually left to the default C locale.
|
||||
/// The C locale does not have any thousands separator, which makes
|
||||
/// large numbers hard to read. This is a specialization of the default
|
||||
/// C++ number punctuation (`std::numpunct`) which separates thousands
|
||||
/// using `'`, which helps getting a quick overview of the magnitude of
|
||||
/// a number. This character is chosen because C++14 allows number literals
|
||||
/// to have this character.
|
||||
///
|
||||
/// @note When using this locale, be sure to restore the old locale in the
|
||||
/// event that the user actually wants to use a non-standard locale.
|
||||
///
|
||||
/// @see std::numpunct
|
||||
struct numpunct : std::numpunct<char>
|
||||
{
|
||||
char do_thousands_sep() const override { return '\''; }
|
||||
|
||||
std::string do_grouping() const override
|
||||
{
|
||||
// See std::numpunct, this separates by thousands.
|
||||
return "\3";
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// @brief Print information about a tensor descriptor.
|
||||
///
|
||||
/// This function dumps useful information from a tensor descriptor to a
|
||||
/// stream, `std::cout` by default. This includes the number of elements
|
||||
/// in the tensor, the size of the backing space, lengths, strides, etc.
|
||||
///
|
||||
/// @note All information is printed using a lightly modified locale to
|
||||
/// get a unified printing experience. The original locale in `stream` is
|
||||
/// temporarily replaced, but restored before the function returns.
|
||||
///
|
||||
/// @tparam DT The tensor element datatype
|
||||
/// @tparam RANK The rank (number of spatial dimensions) of the tensor.
|
||||
///
|
||||
/// @param name A name for the tensor descriptor.
|
||||
/// @param desc The tensor descriptor to print.
|
||||
/// @param out The stream to print to, `std::cout` by default.
|
||||
template <DataType DT, size_t RANK>
|
||||
void print_descriptor(std::string_view name,
|
||||
const TensorDescriptor<DT, RANK>& desc,
|
||||
std::ostream& out = std::cout)
|
||||
{
|
||||
// Create a custom stream with a completely new config (locale,
|
||||
/// precision, fill, etc). Use an osyncstream to buffer the output
|
||||
/// while were at it (its not likely to help a lot, but why not).
|
||||
std::osyncstream stream(out.rdbuf());
|
||||
stream.imbue(std::locale(std::locale(), new detail::numpunct{}));
|
||||
|
||||
// Print name along with some generic info
|
||||
const auto size = desc.get_element_size();
|
||||
const auto space = desc.get_element_space_size();
|
||||
const auto bytes = desc.get_element_space_size_in_bytes();
|
||||
const auto packed = desc.is_packed();
|
||||
|
||||
stream << "Descriptor \"" << name << "\":\n"
|
||||
<< " data type: " << DT << '\n'
|
||||
<< " size: " << size << " elements\n"
|
||||
<< " space: " << space << " elements (" << bytes << " bytes)\n"
|
||||
<< " lengths: " << desc.get_lengths() << '\n'
|
||||
<< " strides: " << desc.get_strides() << '\n'
|
||||
<< " packed: " << (packed ? "yes" : "no") << std::endl;
|
||||
}
|
||||
|
||||
/// @brief User configuration for printing tensors.
|
||||
///
|
||||
/// This structure houses some configuration fields for customizing how tensors
|
||||
/// are printed. The default is usually good, though `TensorPrintConfig::unlimited()`
|
||||
/// is useful if you want to print the entire tensor to the output regardless of size.
|
||||
struct TensorPrintConfig
|
||||
{
|
||||
/// @brief A limit for the number of columns in a tensor row to print.
|
||||
///
|
||||
/// Each row of a tensor will be printed as a sequence of values. At most
|
||||
/// this number of values are printed, if there are more, `row_skip_val`
|
||||
/// will be printed in between.
|
||||
size_t col_limit = 10;
|
||||
|
||||
/// @brief A limit for the number of rows in a 2D matrix to print
|
||||
///
|
||||
/// Tensors with rank higher than 1 are printed as a single matrix or a series
|
||||
/// of matrix slices. At most this number of rows of the matrix will be printed.
|
||||
/// If there are more rows, a row of `matrix_row_skip_val` and possibly
|
||||
/// `row_skip_val` will be printed in between.
|
||||
size_t row_limit = 10;
|
||||
|
||||
/// @brief A limit for the number of 2D tensor slices to print.
|
||||
///
|
||||
/// Tensors with rank higher than 2 are flattened into a sequence of slices. At
|
||||
/// most this number of slices will be printed.
|
||||
size_t slice_limit = 8;
|
||||
|
||||
/// @brief Text to print at the start of a row of values.
|
||||
///
|
||||
/// This is used by `TensorPrinter`, and printed at the start of a row of tensor
|
||||
/// values.
|
||||
std::string_view row_prefix = " ";
|
||||
|
||||
/// @brief Text to print between fields of a row.
|
||||
///
|
||||
/// This is used by `TensorPrinter`, and printed between each value of a row of
|
||||
/// tensor values.
|
||||
std::string_view row_field_sep = " ";
|
||||
|
||||
/// @brief Text to print when skipping some number of row values.
|
||||
///
|
||||
/// This is used by `TensorPrinter`, and printed instead of some number of values
|
||||
/// when the number of values in a row is too large to all print.
|
||||
std::string_view row_skip_val = "...";
|
||||
|
||||
/// @brief Text to print when skipping a row of a matrix.
|
||||
///
|
||||
/// This is used by `TensorPrinter`, and printed instead of a value when some
|
||||
/// number of rows is skipped when printing a matrix. This is similar to
|
||||
/// `row_skip_val`, except in the vertical direction. Note that ALL values
|
||||
/// in the skip row is printed this way.
|
||||
std::string_view matrix_row_skip_val = "...";
|
||||
|
||||
/// @brief The precision of tensor floating point values.
|
||||
///
|
||||
/// Set the number of decimal digits that is printed for a floating point value.
|
||||
int float_precision = 3;
|
||||
|
||||
/// @brief Return the default print config, but without any printing limits.
|
||||
///
|
||||
/// This is useful if you want to print the *entire* tensor, but be aware that
|
||||
/// this may print a lot of data if the tensor is large!
|
||||
constexpr static TensorPrintConfig unlimited()
|
||||
{
|
||||
return {
|
||||
.col_limit = std::numeric_limits<size_t>::max(),
|
||||
.row_limit = std::numeric_limits<size_t>::max(),
|
||||
.slice_limit = std::numeric_limits<size_t>::max(),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// @brief Iterate over a range of values, but limit the amount of iterations.
|
||||
///
|
||||
/// Iterate over values `0..n`, but if `limit > n`, only iterate over the
|
||||
/// first and last few (`limit // 2)` items. This can be used to iterate over
|
||||
/// large ranges in a way that not too many values are visited. Its primarily
|
||||
/// used when printing tensors so that not all values of a giant tensor are
|
||||
/// dumped to the user's terminal.
|
||||
///
|
||||
/// @param n The total number of items to iterate over.
|
||||
/// @param limit The maximum number of items to iterate over. Use even values
|
||||
/// for best results, as this will lead to the same amount of values in the
|
||||
/// "begin" and "end" sections.
|
||||
/// @param f A functor to invoke for each element. The sole parameter is the
|
||||
/// index.
|
||||
/// @param delim A functor to invoke between the begin and end sections. This
|
||||
/// function is only invoked if any items are skipped at all.
|
||||
void limited_foreach(size_t n, size_t limit, auto f, auto delim)
|
||||
{
|
||||
if(n <= limit)
|
||||
{
|
||||
for(size_t i = 0; i < n; ++i)
|
||||
f(i);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto begin_count = (limit + 1) / 2; // Round up in case `delim` is odd.
|
||||
const auto end_count = limit / 2;
|
||||
const auto skip_count = n - limit;
|
||||
|
||||
for(size_t i = 0; i < begin_count; ++i)
|
||||
f(i);
|
||||
|
||||
delim(skip_count);
|
||||
|
||||
for(size_t i = n - end_count; i < n; ++i)
|
||||
f(i);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Output stream requirements for use with `TensorPrinter`.
|
||||
///
|
||||
/// The `TensorPrinter` does not write to an ostream directly, but rather writes to
|
||||
/// a custom stream object. This is mainly so that the user of `TensorPrinter` can
|
||||
/// get more details than directly with an ostream. Basically, a valid implementation
|
||||
/// of `TensorPrintStream` exposes 3 things:
|
||||
/// - A way to print (stringified) tensor elements.
|
||||
/// - A way to print arbitrary text messages. These are mostly for formatting. This
|
||||
/// should be implemented using varargs which are directly folded into an ostream,
|
||||
/// so that <iomanip> functions can be used.
|
||||
/// - A way to query the max width of any `val` field.
|
||||
///
|
||||
/// @see TensorPrinter for more information.
|
||||
template <typename Stream>
|
||||
concept TensorPrintStream = requires(Stream& stream, std::string_view val) {
|
||||
{ stream.max_width } -> std::convertible_to<size_t>;
|
||||
{ stream.val(val) } -> std::same_as<void>;
|
||||
{ stream.msg() } -> std::same_as<void>;
|
||||
{ stream.msg("msg") } -> std::same_as<void>;
|
||||
{ stream.msg(std::setw(3), std::setfill(4), "msg", val) } -> std::same_as<void>;
|
||||
};
|
||||
|
||||
/// @brief Utility to print tensors.
|
||||
///
|
||||
/// This structure implements the main logic for printing tensors to a stream.
|
||||
/// In order to help with formatting, the `TensorPrinter` abstracts over a custom
|
||||
/// stream type, see `TensorPrintStream`. This type is actually mostly an internal
|
||||
/// helper and mainly used by `print_tensor`. Its supposed to be constructed
|
||||
/// manually, but see the field docs for what is required.
|
||||
///
|
||||
/// @tparam DT The data type of the tensor to print.
|
||||
/// @tparam RANK The rank (number of spatial dimensions) of the tensor to print.
|
||||
///
|
||||
/// @see print_tensor
|
||||
template <DataType DT, size_t RANK>
|
||||
struct TensorPrinter
|
||||
{
|
||||
/// The name of this tensor. This will be used during printing to add extra
|
||||
/// clarity about what the user is seeing.
|
||||
std::string_view name;
|
||||
|
||||
/// Configuration details of how to print the tensor. This should be able to
|
||||
/// be specified by the user, but the default is good in most cases.
|
||||
TensorPrintConfig config;
|
||||
|
||||
/// The lengths of the tensor to print. These values are directly from
|
||||
/// `TensorDescriptor::get_lengths()`, stored here to avoid querying them
|
||||
/// repeatedly.
|
||||
Extent<RANK> lengths;
|
||||
|
||||
/// The strides of the tensor to print. These values are directly from
|
||||
/// `TensorDescriptor::get_strides()`, stored here to avoid querying them
|
||||
/// repeatedly.
|
||||
Extent<RANK> strides;
|
||||
|
||||
/// The tensor's backing buffer. This memory should be host-accessible, for
|
||||
/// example by copying it back to the host first.
|
||||
const void* h_buffer;
|
||||
|
||||
/// A common stringstream for stringifying tensor values. This is here mostly
|
||||
/// so that we can cache the internal allocation.
|
||||
std::stringstream ss;
|
||||
|
||||
/// @brief Low-level tensor value stringifying function.
|
||||
///
|
||||
/// Print value `value` to the stringstream `ss` (member value). This function
|
||||
/// is the actual low-level printing function that prints each element of the
|
||||
/// tensor. In order to get a robust printing implementation, the value is written
|
||||
/// directly into a stringstream, which is then further processed to be actually
|
||||
/// written to the output. This way, the format doesn't depend on the ostream
|
||||
/// configuration.
|
||||
///
|
||||
/// @param value The value to print to the stream.
|
||||
void stringify_value(const void* value)
|
||||
{
|
||||
if constexpr(DT == DataType::UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
ss << "??";
|
||||
return;
|
||||
}
|
||||
|
||||
using CKType = detail::cpp_type_t<DT>;
|
||||
const auto ck_value = *static_cast<const CKType*>(value);
|
||||
|
||||
if constexpr(DT == DataType::I32 || DT == DataType::I8 || DT == DataType::U8)
|
||||
ss << ck_value;
|
||||
else if constexpr(DT == DataType::FP64 || DT == DataType::FP32)
|
||||
ss << std::fixed << std::setprecision(config.float_precision) << ck_value;
|
||||
else if constexpr(DT == DataType::FP16 || DT == DataType::BF16 || DT == DataType::FP8 ||
|
||||
DT == DataType::BF8)
|
||||
ss << std::fixed
|
||||
<< std::setprecision(config.float_precision)
|
||||
// Note: We are using CK types here (cpp_type_t uses DataTypeToCK), so
|
||||
// use CK's type_convert function.
|
||||
<< ::ck::type_convert<float>(ck_value);
|
||||
else
|
||||
// TODO: Tuple types? Currently not implemented in DataTypeToCK...
|
||||
static_assert(false, "stringify_value unsupported data type, please implement");
|
||||
}
|
||||
|
||||
/// @brief Print the value at an index to a stream.
|
||||
///
|
||||
/// This function reads the value at `index` and prints it to `stream` (using
|
||||
/// `stream.val(...)`).
|
||||
///
|
||||
/// @param stream The stream to print to.
|
||||
/// @param index The index in the tensor of the value to print.
|
||||
void print_value(TensorPrintStream auto& stream, const Extent<RANK>& index)
|
||||
{
|
||||
const auto offset = calculate_offset(index, strides);
|
||||
const auto* value_ptr =
|
||||
&static_cast<const std::byte*>(h_buffer)[offset * data_type_sizeof(DT)];
|
||||
|
||||
// Reset the stream without allocating.
|
||||
// ss.str("") allocates...
|
||||
ss.clear();
|
||||
ss.seekg(0);
|
||||
ss.seekp(0);
|
||||
stringify_value(value_ptr);
|
||||
// ss.view() returns a view of the ENTIRE buffer, which may have
|
||||
// lingering data since we used seekp() and seekg() to reset the
|
||||
// stream. For some reason std::stringstream works this way...
|
||||
// Fortunately tellp() returns how many bytes we've actually
|
||||
// written.
|
||||
const auto view = ss.view().substr(0, ss.tellp());
|
||||
stream.val(view);
|
||||
}
|
||||
|
||||
/// @brief Print a 1D row to a stream.
|
||||
///
|
||||
/// Print a row of tensor values to the stream. This function is used for both
|
||||
/// 1D tensors and for rows of 2D tensors, in which the base coordinate is given
|
||||
/// by `index`. Note that the print configuration is taken into account to avoid
|
||||
/// flooding the user's terminal with values.
|
||||
///
|
||||
/// @param stream The stream to print to.
|
||||
/// @param index The index of the row to print. The rightmost index element is
|
||||
/// ignored, as that is the index of the value _within_ the row.
|
||||
void print_row(TensorPrintStream auto& stream, Extent<RANK>& index)
|
||||
{
|
||||
// See note in `print_matrix`.
|
||||
stream.msg(config.row_prefix);
|
||||
limited_foreach(
|
||||
lengths[RANK - 1],
|
||||
config.col_limit,
|
||||
[&](auto i) {
|
||||
stream.msg(config.row_field_sep);
|
||||
index[RANK - 1] = i;
|
||||
print_value(stream, index);
|
||||
},
|
||||
[&]([[maybe_unused]] auto skip_count) {
|
||||
stream.msg(config.row_field_sep);
|
||||
// Note: Not using stream.val(...) here because we don't want this
|
||||
// field to partake in max_width computation, nor do we want to
|
||||
// pad it to the max width.
|
||||
stream.msg(config.row_skip_val);
|
||||
});
|
||||
|
||||
stream.msg('\n');
|
||||
}
|
||||
|
||||
/// @brief Print a 2D matrix to a stream.
|
||||
///
|
||||
/// Print a matrix of tensor values to the stream. This function is used for both
|
||||
/// 2D and slices of higher-dimensional tensors, in which the base coordinate is
|
||||
/// given by `index`. Note that the print configuration is taken into account to
|
||||
/// avoid flooding the user's terminal with values.
|
||||
///
|
||||
/// @param stream The stream to print to.
|
||||
/// @param index The index of the row to print. The 2 rightmost index elements are
|
||||
/// ignored, as those are the indices of values _within_ the matrix.
|
||||
void print_matrix(TensorPrintStream auto& stream, Extent<RANK>& index)
|
||||
{
|
||||
limited_foreach(
|
||||
lengths[RANK - 2],
|
||||
config.row_limit,
|
||||
[&](auto i) {
|
||||
index[RANK - 2] = i;
|
||||
print_row(stream, index);
|
||||
},
|
||||
[&]([[maybe_unused]] auto row_skip_count) {
|
||||
// When we encounter a skip row, continue with the same logic
|
||||
// as printing 1D tensor rows. Instead of actual values, we will
|
||||
// simply print MATRIX_ROW_SKIP_VAL (usually something like "...").
|
||||
stream.msg(config.row_prefix);
|
||||
limited_foreach(
|
||||
lengths[RANK - 1],
|
||||
config.col_limit,
|
||||
[&]([[maybe_unused]] auto i) {
|
||||
stream.msg(config.row_field_sep);
|
||||
// Note: We're using `stream.val(...)` here because we *do* want this field
|
||||
// to partake in max_width computation, and we *do* want to pad it like
|
||||
// value fields. This is so that these appear the same width as actual
|
||||
// values, so that everything is neatly aligned. This also ensures that if
|
||||
// there are no skip values, then the size of the skip field is not taken
|
||||
// into account.
|
||||
stream.val(config.matrix_row_skip_val);
|
||||
},
|
||||
[&]([[maybe_unused]] auto col_skip_count) {
|
||||
stream.msg(config.row_field_sep);
|
||||
// Note: Not using stream.val(...) here because we don't want this
|
||||
// field to partake in max_width computation, nor do we want to
|
||||
// pad it to the max width.
|
||||
stream.msg(config.row_skip_val);
|
||||
});
|
||||
stream.msg('\n');
|
||||
});
|
||||
}
|
||||
|
||||
/// @brief Print a tensor to a stream.
|
||||
///
|
||||
/// This is the main tensor printing function. It calls `print_row` or `print_matrix`
|
||||
/// (possibly repeatedly) as required. This function prints the entire tensor in
|
||||
/// `h_buffer` regardless.
|
||||
///
|
||||
/// @param stream The stream to print to.
|
||||
void print_tensor(TensorPrintStream auto& stream)
|
||||
{
|
||||
Extent<RANK> zero_coord = {};
|
||||
if constexpr(RANK == 0)
|
||||
{
|
||||
// 0D case: just print the one value
|
||||
stream.msg(config.row_prefix);
|
||||
stream.msg(config.row_field_sep);
|
||||
print_value(stream, zero_coord);
|
||||
stream.msg('\n');
|
||||
}
|
||||
else if constexpr(RANK == 1)
|
||||
{
|
||||
// 1D case: dump everything on one line
|
||||
print_row(stream, zero_coord);
|
||||
}
|
||||
else if constexpr(RANK == 2)
|
||||
{
|
||||
// 2D case: print a 2D matrix
|
||||
print_matrix(stream, zero_coord);
|
||||
}
|
||||
else
|
||||
{
|
||||
// For higher dimensions, print each window as a slice
|
||||
// We want to limit the *total* number of slices using `slice_limit`,
|
||||
// not the number in each axis. So flatten the remaining dimensions.
|
||||
// This also avoids recursion in this function in general.
|
||||
|
||||
// First get the shape minus the 2 inner dimensions
|
||||
Extent<RANK - 2> outer_shape;
|
||||
std::copy_n(lengths.begin(), RANK - 2, outer_shape.begin());
|
||||
|
||||
NdIter iter(outer_shape);
|
||||
detail::limited_foreach(
|
||||
iter.numel(),
|
||||
config.slice_limit,
|
||||
[&](auto outer_flat_index) {
|
||||
// Now decode the outer index and turn it back into a complete index
|
||||
const auto outer_index = iter(outer_flat_index);
|
||||
Extent<RANK> index = {};
|
||||
std::copy_n(outer_index.begin(), RANK - 2, index.begin());
|
||||
|
||||
// Print an extra separating line between two slices
|
||||
if(outer_flat_index != 0)
|
||||
stream.msg('\n');
|
||||
|
||||
// Print an information header about the current slice
|
||||
stream.msg("Tensor \"", name, "\", slice [");
|
||||
for(auto x : outer_index)
|
||||
stream.msg(x, ", ");
|
||||
stream.msg(":, :]\n");
|
||||
|
||||
// And print is as matrix
|
||||
print_matrix(stream, index);
|
||||
},
|
||||
[&](auto skip_count) { stream.msg("\n(skipping ", skip_count, " slices...)\n"); });
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Implementation of `TensorPrintStream` to figure out the maximum
|
||||
/// width of a field.
|
||||
///
|
||||
/// In order to produce neatly aligned tensors, where all values of each row
|
||||
/// appear on the same columns, we have to figure out the maximum width of
|
||||
/// each field. This print stream helps with that: It does not actually print
|
||||
/// anything, it just figures out the maximum width of any value (not message).
|
||||
///
|
||||
/// @details OK, this function does actually print things, but only to an
|
||||
/// internal `stringstream`. This is so that we can easily figure out the
|
||||
/// width of the field (in bytes), just by counting the amount of bytes
|
||||
/// written into the string stream.
|
||||
///
|
||||
/// @see TensorPrintStream
|
||||
struct MaxFieldWidthStream
|
||||
{
|
||||
size_t max_width = 0;
|
||||
|
||||
/// @brief Print a tensor value to the stream
|
||||
///
|
||||
/// "Print" a value to the stream. This function figures out the width
|
||||
/// of the value when printed, and then composes it with `max_width` to
|
||||
/// figure out the total maximum.
|
||||
///
|
||||
/// @param value The value to print.
|
||||
void val(std::string_view value) { max_width = std::max(max_width, value.size()); }
|
||||
|
||||
/// @brief Print a message to the stream.
|
||||
///
|
||||
/// "Print" a non-value message to the stream. In this implementation,
|
||||
/// everything is discarded.
|
||||
///
|
||||
/// @tparam Args the types of the values to print.
|
||||
///
|
||||
/// @param args The values to print.
|
||||
template <typename... Args>
|
||||
void msg([[maybe_unused]] const Args&... args)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Implementation of `TensorPrintStream` which actually prints.
|
||||
///
|
||||
/// In contrast to `MaxFieldWidthStream`, this function actually prints
|
||||
/// to an ostream, taking the value produced by that type into account.
|
||||
struct OutputStream
|
||||
{
|
||||
std::ostream& stream;
|
||||
// The maximum width of each tensor value.
|
||||
size_t max_width;
|
||||
|
||||
/// @brief Print a tensor value to the stream
|
||||
///
|
||||
/// Actually print a value into the stream, (right-)padding it to
|
||||
/// `max_width`.
|
||||
///
|
||||
/// @param value The value to print.
|
||||
void val(std::string_view value)
|
||||
{
|
||||
stream << std::setfill(' ') << std::setw(max_width) << value;
|
||||
}
|
||||
|
||||
/// @brief Print a message to the stream.
|
||||
///
|
||||
/// This prints a non-value message directly to the ostream, as if
|
||||
/// folded via `operator<<`.
|
||||
///
|
||||
/// @tparam Args the types of the values to print.
|
||||
///
|
||||
/// @param args The values to print.
|
||||
template <typename... Args>
|
||||
void msg(const Args&... args)
|
||||
{
|
||||
(stream << ... << args);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// @brief Print device tensor values to an ostream.
|
||||
///
|
||||
/// Print the values of a tensor to an ostream. This function neatly formats
|
||||
/// the tensor according to `config`, tabulating the values so that they are
|
||||
/// vertically aligned and skipping values to prevent flooding the terminal.
|
||||
/// With the default config, this function is good to get a quick overview
|
||||
/// of what a tensor looks like. For a more complete overview, consider
|
||||
/// supplying `TensorPrintConfig::unlimited()` to get everything (but beware
|
||||
/// of flooding the terminal). Tensors are printed with the rightmost-dimension
|
||||
/// as inner dimension, these values appear on the same row in the output.
|
||||
///
|
||||
/// @tparam DT The data type of the tensor.
|
||||
/// @tparam RANK The rank (number of spatial dimensions) of the tensor.
|
||||
///
|
||||
/// @param name A name for the tensor. This will be used to add some extra identifying
|
||||
/// information during printing.
|
||||
/// @param desc The descriptor for the tensor memory layout.
|
||||
/// @param d_buffer The tensor's actual data buffer. This is expected to be
|
||||
/// _device accessible_ memory, as its copied back to the host first.
|
||||
/// @param config Tensor printing configuration. This allows tweaking some details
|
||||
/// of the printing process.
|
||||
/// @param out The ostream to print to, `std::cout` by default.
|
||||
template <DataType DT, size_t RANK>
|
||||
void print_tensor(std::string_view name,
|
||||
const TensorDescriptor<DT, RANK>& desc,
|
||||
const void* d_buffer,
|
||||
TensorPrintConfig config = {},
|
||||
std::ostream& out = std::cout)
|
||||
{
|
||||
// Copy memory to the host (printing from device is sketchy)
|
||||
const auto space = desc.get_element_space_size_in_bytes();
|
||||
std::vector<std::byte> h_buffer(space);
|
||||
check_hip(hipMemcpy(h_buffer.data(), d_buffer, space, hipMemcpyDeviceToHost));
|
||||
|
||||
// Create a custom stream with a completely new config (locale,
|
||||
/// precision, fill, etc). Use an osyncstream to buffer the output
|
||||
/// while were at it (its not likely to help a lot, but why not).
|
||||
std::osyncstream stream(out.rdbuf());
|
||||
stream.imbue(std::locale(std::locale(), new detail::numpunct{}));
|
||||
|
||||
// Print a header for the entire tensor (regardless of if there are multiple slices).
|
||||
stream << "Tensor \"" << name << "\": shape = " << desc.get_lengths() << "\n";
|
||||
|
||||
detail::TensorPrinter<DT, RANK> printer = {
|
||||
.name = name,
|
||||
.config = config,
|
||||
.lengths = desc.get_lengths(),
|
||||
.strides = desc.get_strides(),
|
||||
.h_buffer = h_buffer.data(),
|
||||
.ss = std::stringstream(),
|
||||
};
|
||||
|
||||
// We're actually going to print twice: once to figure out the
|
||||
// maximum width of the fields, and once to actually print to the stream.
|
||||
|
||||
// Print once to figure out the maximum field width.
|
||||
detail::MaxFieldWidthStream max_field_width;
|
||||
printer.print_tensor(max_field_width);
|
||||
|
||||
// Actually print to the output stream.
|
||||
detail::OutputStream tensor_out = {
|
||||
.stream = stream,
|
||||
.max_width = max_field_width.max_width,
|
||||
};
|
||||
printer.print_tensor(tensor_out);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -81,4 +81,15 @@ inline DeviceBuffer alloc_buffer(size_t size)
|
||||
return DeviceBuffer(d_buf);
|
||||
}
|
||||
|
||||
/// @brief "Align" an offset to a multiple of a particular alignment.
|
||||
///
|
||||
/// Returns `addr` aligned to the next multiple of `alignment`.
|
||||
///
|
||||
/// @param addr The address to align.
|
||||
/// @param alignment The alignment.
|
||||
inline size_t align_fwd(size_t addr, size_t alignment)
|
||||
{
|
||||
return addr % alignment == 0 ? addr : addr - addr % alignment + alignment;
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <array>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <iosfwd>
|
||||
#include <concepts>
|
||||
#include <algorithm>
|
||||
#include <hip/hip_runtime.h>
|
||||
@@ -123,6 +124,33 @@ struct Extent : std::array<size_t, RANK>
|
||||
template <typename... T>
|
||||
Extent(T...) -> Extent<sizeof...(T)>;
|
||||
|
||||
/// @brief Extent printer
|
||||
///
|
||||
/// This function implements an ostream printing overload for `Extent`, so that
|
||||
/// they can be printed in the usual `stream << extent` fashion.
|
||||
///
|
||||
/// @tparam RANK Rank (number of spatial dimensions) of the extent.
|
||||
///
|
||||
/// @param stream The stream to print the extent to.
|
||||
/// @param extent The extent to print to the stream.
|
||||
template <size_t RANK>
|
||||
std::ostream& operator<<(std::ostream& stream, const Extent<RANK>& extent)
|
||||
{
|
||||
stream << '[';
|
||||
bool first = true;
|
||||
for(const auto x : extent)
|
||||
{
|
||||
if(first)
|
||||
first = false;
|
||||
else
|
||||
stream << ", ";
|
||||
|
||||
stream << x;
|
||||
}
|
||||
|
||||
return stream << ']';
|
||||
}
|
||||
|
||||
/// @brief Concept for automatically deriving tensor memory layout.
|
||||
///
|
||||
/// A `TensorStridesGenerator` is a type which can be used to automatically
|
||||
|
||||
@@ -18,6 +18,102 @@
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
/// @brief Utility structure for N-dimensional iteration using a flat index
|
||||
///
|
||||
/// This structure's main purpose is to "unmerge" a flattened index into a
|
||||
/// multi-dimensional index, which helps when iterating over multi-dimensional
|
||||
/// indices without having to write an arbitrary amount of nested for loops.
|
||||
/// A minimal amount of precomputation must be done to do this efficiently,
|
||||
/// which is handled in the constructor of this type.
|
||||
///
|
||||
/// @details Decoding a flat index into a multi-dimensional index is done by
|
||||
/// first computing a reverse scan of the shape. These values can then be
|
||||
/// used to decode the index in the usual way:
|
||||
///
|
||||
/// x = flat_idx / (size_y * size_z)
|
||||
/// y = flat_idx % (size_y * size_z) / size_z
|
||||
/// z = flat_idx % (size_y * size_z) % size_z
|
||||
/// etc
|
||||
///
|
||||
/// The decode order is such that the innermost dimension (right in
|
||||
/// the shape extent) changes the fastest.
|
||||
///
|
||||
/// @tparam RANK The rank (number of spatial dimensions) of the tensor to
|
||||
/// iterate.
|
||||
template <size_t RANK>
|
||||
struct NdIter
|
||||
{
|
||||
/// @brief Prepare N-dimensional iteration over a particular shape.
|
||||
///
|
||||
/// Precompute ashape into a form that can be used to easily decode a flat
|
||||
/// index into a multi-dimensional index.
|
||||
///
|
||||
/// @param shape The shape to iterate over.
|
||||
explicit NdIter(const Extent<RANK>& shape)
|
||||
{
|
||||
// Precompute shape_scan = [..., shape[-2] * shape[-1], shape[-1], 1]
|
||||
|
||||
numel_ = 1;
|
||||
for(int i = RANK; i > 0; --i)
|
||||
{
|
||||
shape_scan_[i - 1] = numel_;
|
||||
numel_ *= shape[i - 1];
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Unflatten a flat index into a multi-dimensional index
|
||||
///
|
||||
/// This applies the usual multi-dimensional indexing method over the
|
||||
/// precomputed shape scan to get back a multi-dimensional index.
|
||||
/// The decode order is such that the innermost dimension (right in
|
||||
/// the shape extent) changes the fastest.
|
||||
///
|
||||
/// @param flat_index The "flattened" (1-dimensional) index of the tensor
|
||||
///
|
||||
/// @returns A multi-dimensional index into the tensor
|
||||
///
|
||||
/// @pre `0 <= flat_index < size()` (in other words, the `flat_index` must
|
||||
/// be in bounds of the tensor shape that this `NdIter` was made from).
|
||||
__host__ __device__ Extent<RANK> operator()(size_t flat_index) const
|
||||
{
|
||||
Extent<RANK> index = {};
|
||||
auto idx = flat_index;
|
||||
for(size_t i = 0; i < RANK; ++i)
|
||||
{
|
||||
const auto scanned_dim = shape_scan_[i];
|
||||
index[i] = idx / scanned_dim;
|
||||
idx %= scanned_dim;
|
||||
}
|
||||
|
||||
return index;
|
||||
}
|
||||
|
||||
/// @brief Return the total elements to iterate over
|
||||
///
|
||||
/// Get the total number of elements in the shape to iterate over. This value
|
||||
/// can be used to construct a complete for loop to iterate over all indices
|
||||
/// of a tensor, for example:
|
||||
///
|
||||
/// for(size_t i = 0; i < iter.numel(); ++i)
|
||||
/// {
|
||||
/// const auto index = iter(i);
|
||||
/// use(index);
|
||||
/// }
|
||||
__host__ __device__ size_t numel() const { return numel_; }
|
||||
|
||||
private:
|
||||
/// Reverse (right) scan of the shape to iterate over.
|
||||
Extent<RANK> shape_scan_;
|
||||
|
||||
/// The total number of elements in the shape. This value turns out to be almost
|
||||
/// always required when iterating over a shape, so just store it in this type
|
||||
/// so that it is easily accessible.
|
||||
size_t numel_;
|
||||
};
|
||||
|
||||
template <size_t RANK>
|
||||
NdIter(Extent<RANK>) -> NdIter<RANK>;
|
||||
|
||||
/// @brief Concept for constraining tensor iteration functors.
|
||||
///
|
||||
/// This concept checks that a functor has the correct signature for
|
||||
@@ -50,28 +146,19 @@ constexpr int DEVICE_FOREACH_BLOCK_SIZE = 256;
|
||||
/// @tparam F The type of the callback to invoke. This function must be
|
||||
/// compatible with execution as a __device__ function.
|
||||
///
|
||||
/// @param numel The total number of elements in the tensor.
|
||||
/// @param shape_scan A right-exclusive scan of the shape of the tensor.
|
||||
/// @param iter An NdIter instance to help iterating over the tensor.
|
||||
/// @param f The callback to invoke for each index of the tensor. This
|
||||
/// functor must be eligible for running on the GPU.
|
||||
template <int BLOCK_SIZE, size_t RANK, typename F>
|
||||
requires ForeachFunctor<F, RANK>
|
||||
__global__ __launch_bounds__(BLOCK_SIZE) //
|
||||
void foreach_kernel(const size_t numel, Extent<RANK> shape_scan, F f)
|
||||
void foreach_kernel(NdIter<RANK> iter, F f)
|
||||
{
|
||||
const auto gid = blockIdx.x * BLOCK_SIZE + threadIdx.x;
|
||||
for(size_t flat_idx = gid; flat_idx < numel; flat_idx += gridDim.x * BLOCK_SIZE)
|
||||
for(size_t flat_idx = gid; flat_idx < iter.numel(); flat_idx += gridDim.x * BLOCK_SIZE)
|
||||
{
|
||||
// Compute the current index.
|
||||
Extent<RANK> index = {};
|
||||
|
||||
size_t idx = flat_idx;
|
||||
for(size_t i = 0; i < RANK; ++i)
|
||||
{
|
||||
const auto scanned_dim = shape_scan[i];
|
||||
index[i] = idx / scanned_dim;
|
||||
idx %= scanned_dim;
|
||||
}
|
||||
const auto index = iter(flat_idx);
|
||||
|
||||
// Then invoke the callback with the index.
|
||||
f(index);
|
||||
@@ -160,18 +247,12 @@ void tensor_foreach(const Extent<RANK>& shape, ForeachFunctor<RANK> auto f)
|
||||
// order in the kernel is from large-to-small. Right layout is the
|
||||
// easiest solution for that.
|
||||
|
||||
Extent<RANK> shape_scan;
|
||||
size_t numel = 1;
|
||||
for(int i = RANK; i > 0; --i)
|
||||
{
|
||||
shape_scan[i - 1] = numel;
|
||||
numel *= shape[i - 1];
|
||||
}
|
||||
NdIter iter(shape);
|
||||
|
||||
// Reset any errors from previous launches.
|
||||
(void)hipGetLastError();
|
||||
|
||||
kernel<<<occupancy * multiprocessors, block_size>>>(numel, shape_scan, f);
|
||||
kernel<<<occupancy * multiprocessors, block_size>>>(iter, f);
|
||||
check_hip(hipGetLastError());
|
||||
}
|
||||
|
||||
@@ -179,7 +260,7 @@ void tensor_foreach(const Extent<RANK>& shape, ForeachFunctor<RANK> auto f)
|
||||
///
|
||||
/// This concept checks that a functor has the correct signature for
|
||||
/// use with the `fill_tensor` function.
|
||||
template <typename F, builder::DataType DT, size_t RANK>
|
||||
template <typename F, DataType DT, size_t RANK>
|
||||
concept FillTensorFunctor = requires(const F& f, const Extent<RANK>& index) {
|
||||
{ f(index) } -> std::convertible_to<detail::cpp_type_t<DT>>;
|
||||
};
|
||||
@@ -199,7 +280,7 @@ concept FillTensorFunctor = requires(const F& f, const Extent<RANK>& index) {
|
||||
/// @param f A functor used to get the value at a particular coordinate.
|
||||
///
|
||||
/// @see FillTensorFunctor
|
||||
template <builder::DataType DT, size_t RANK>
|
||||
template <DataType DT, size_t RANK>
|
||||
void fill_tensor(const TensorDescriptor<DT, RANK>& desc,
|
||||
void* buffer,
|
||||
FillTensorFunctor<DT, RANK> auto f)
|
||||
@@ -218,7 +299,7 @@ void fill_tensor(const TensorDescriptor<DT, RANK>& desc,
|
||||
///
|
||||
/// This concept checks that a functor has the correct signature for
|
||||
/// use with the `fill_tensor_buffer` function.
|
||||
template <typename F, builder::DataType DT>
|
||||
template <typename F, DataType DT>
|
||||
concept FillTensorBufferFunctor = requires(const F& f, size_t index) {
|
||||
{ f(index) } -> std::convertible_to<detail::cpp_type_t<DT>>;
|
||||
};
|
||||
@@ -239,7 +320,7 @@ concept FillTensorBufferFunctor = requires(const F& f, size_t index) {
|
||||
/// @param f A functor used to get the value at a particular index.
|
||||
///
|
||||
/// @see FillTensorBufferFunctor
|
||||
template <builder::DataType DT, size_t RANK>
|
||||
template <DataType DT, size_t RANK>
|
||||
void fill_tensor_buffer(const TensorDescriptor<DT, RANK>& desc,
|
||||
void* buffer,
|
||||
FillTensorBufferFunctor<DT> auto f)
|
||||
@@ -247,7 +328,19 @@ void fill_tensor_buffer(const TensorDescriptor<DT, RANK>& desc,
|
||||
fill_tensor(desc.get_space_descriptor(), buffer, [f](auto index) { return f(index[0]); });
|
||||
}
|
||||
|
||||
template <builder::DataType DT, size_t RANK>
|
||||
/// @brief Utility for clearing tensor buffers to a particular value.
|
||||
///
|
||||
/// This function initializes all memory backing a particular tensor buffer to
|
||||
/// one specific value, zero by default. Note that this function ignores strides,
|
||||
/// and clears the entire buffer backing the tensor.
|
||||
///
|
||||
/// @tparam DT The tensor element datatype
|
||||
/// @tparam RANK The rank (number of spatial dimensions) of the tensor.
|
||||
///
|
||||
/// @param desc The descriptor of the tensor to initialize.
|
||||
/// @param buffer The memory of the tensor to initialize.
|
||||
/// @param value The value to initialize the tensor buffer with.
|
||||
template <DataType DT, size_t RANK>
|
||||
void clear_tensor_buffer(const TensorDescriptor<DT, RANK>& desc,
|
||||
void* buffer,
|
||||
detail::cpp_type_t<DT> value = detail::cpp_type_t<DT>{0})
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
#include <concepts>
|
||||
|
||||
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_buffer.hpp"
|
||||
#include "ck_tile/builder/testing/validation.hpp"
|
||||
|
||||
/// This file is the main header for the CK-Builder testing system. A high-level
|
||||
@@ -132,8 +134,8 @@ struct Outputs;
|
||||
/// be created using `alloc_inputs()` and that an instance of the corresponding
|
||||
/// `Inputs` structure can be obtained using `.get()`.
|
||||
///
|
||||
/// @note The easiest way to implement this type is to use the `DeviceBuffer`
|
||||
/// type to allocate individual device buffers for each input tensor.
|
||||
/// @note A default implementation is provided for this type if `Inputs`
|
||||
/// supports `TensorReflectable`.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature to specialize the structure for.
|
||||
///
|
||||
@@ -151,8 +153,8 @@ struct UniqueInputs;
|
||||
/// be created using `alloc_outputs()` and that an instance of the corresponding
|
||||
/// `Outputs` structure can be obtained using `.get()`.
|
||||
///
|
||||
/// @note The easiest way to implement this type is to use the `DeviceBuffer`
|
||||
/// type to allocate individual device buffers for each output tensor.
|
||||
/// @note A default implementation is provided for this type if `Outputs`
|
||||
/// supports `TensorReflectable`.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature to specialize the structure for.
|
||||
///
|
||||
@@ -197,6 +199,12 @@ concept ValidUniqueOutputs = requires(UniqueOutputs<SIGNATURE>& inputs) {
|
||||
/// amount of memory required and then allocate it on the device, for example
|
||||
/// using `alloc_buffer` or `alloc_tensor_buffer`.
|
||||
///
|
||||
/// @note This function is explicitly deleted to generate compile errors
|
||||
/// for missing implementations.
|
||||
///
|
||||
/// @note A default implementation is provided for this function if `Inputs`
|
||||
/// supports `TensorReflectable`.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature to specialize the structure for.
|
||||
///
|
||||
/// @param args The run-time arguments of the operation.
|
||||
@@ -207,22 +215,22 @@ concept ValidUniqueOutputs = requires(UniqueOutputs<SIGNATURE>& inputs) {
|
||||
/// @see alloc_tensor_buffer()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidUniqueInputs<SIGNATURE>
|
||||
UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args);
|
||||
UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args) = delete;
|
||||
|
||||
/// @brief Allocate inputs corresponding to a signature.
|
||||
/// @brief Initialize inputs corresponding to a signature.
|
||||
///
|
||||
/// The `init_inputs()` function is used to initialize pseudo-random data
|
||||
/// to the tensors specified in the Inputs structure. Implementors should
|
||||
/// fill each of the tensors in `inputs` with appropriate random data.
|
||||
///
|
||||
/// @note This function is explicitly deleted to generate compile errors
|
||||
/// for missing implementations.
|
||||
///
|
||||
/// @tparam SIGNATURE the signature to specialize the structure for.
|
||||
///
|
||||
/// @param args The run-time arguments of the operation.
|
||||
/// @param inputs The operation inputs to initialize with random data.
|
||||
///
|
||||
/// @note This function is explicitly deleted to generate compile errors
|
||||
/// for missing implementations.
|
||||
///
|
||||
/// @see Inputs
|
||||
/// @see tensor_initialization
|
||||
template <auto SIGNATURE>
|
||||
@@ -235,13 +243,16 @@ void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs) = delete
|
||||
/// amount of memory required and then allocate it on the device, for example
|
||||
/// using `alloc_buffer` or `alloc_tensor_buffer`.
|
||||
///
|
||||
/// @note This function is explicitly deleted to generate compile errors
|
||||
/// for missing implementations.
|
||||
///
|
||||
/// @note A default implementation is provided for this function if `Outputs`
|
||||
/// supports `TensorReflectable`.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature to specialize the structure for.
|
||||
///
|
||||
/// @param args The run-time arguments of the operation.
|
||||
///
|
||||
/// @note This function is explicitly deleted to generate compile errors
|
||||
/// for missing implementations.
|
||||
///
|
||||
/// @see Outputs
|
||||
/// @see UniqueOutputs
|
||||
/// @see alloc_buffer()
|
||||
@@ -262,15 +273,15 @@ UniqueInputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args) = delete;
|
||||
/// were incorrect, and where (a subset of) those elements are located within
|
||||
/// the tensor. See `ValidationReport` for more information about the report.
|
||||
///
|
||||
/// @note This function is explicitly deleted to generate compile errors
|
||||
/// for missing implementations.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature to specialize the structure for.
|
||||
///
|
||||
/// @param args The run-time arguments of the operation.
|
||||
/// @param actual The actual results, the results of the operation to-be-tested.
|
||||
/// @param expected The expected results, the results of the reference implementation.
|
||||
///
|
||||
/// @note This function is explicitly deleted to generate compile errors
|
||||
/// for missing implementations.
|
||||
///
|
||||
/// @see ValidationReport
|
||||
template <auto SIGNATURE>
|
||||
ValidationReport validate(const Args<SIGNATURE>& args,
|
||||
|
||||
@@ -0,0 +1,199 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string_view>
|
||||
|
||||
/// testing.hpp requires developers of a type of SIGNATURE to implement
|
||||
/// quite a lot of functionality for each SIGNATURE. For example, next
|
||||
/// to `Args`, `Inputs`, `Outputs`, `run`, they also have to define
|
||||
/// `UniqueInputs`, `UniqueOutputs`, `alloc_inputs`, `alloc_outputs`,
|
||||
/// and `validate`. The implementation of these latter few functions
|
||||
/// is usually quite straight forward and adds a bunch of copy-paste
|
||||
/// overhead. The functionality in this file offers an alternative
|
||||
/// route: By implementing some reflection functionality in `Inputs`
|
||||
/// and `Outputs`, we can automatically derive most of the functionality.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
/// @brief Check whether an `Input` or `Output` struct can be reflected.
|
||||
///
|
||||
/// In order to avoid having to manually redefine a bunch of types related to
|
||||
/// each `Inputs`/`Outputs` structure, those structures can also provide some
|
||||
/// "reflection" functionality. To this end, they should implement
|
||||
/// `static void reflect(const Args<SIGNATURE> args&, auto inspect)`, where `inspect`
|
||||
/// is called with information about each field in the struct. In more detail,
|
||||
/// the signature of the `inspect` function is as follows:
|
||||
///
|
||||
/// void inspect(
|
||||
/// // A human-readable name for the tensor
|
||||
/// std::string_view name,
|
||||
/// // Descriptor for the tensor in memory, usually obtained via `args`.
|
||||
/// const TensorDescriptor<DT, RANK>& desc,
|
||||
/// // Member pointer to a field of `T`, which is a GPU-memory pointer
|
||||
/// // to the relevant tensor memory.
|
||||
/// void* T::* ptr);
|
||||
///
|
||||
/// Here, `T` is `Inputs<SIGNATURE>` or `Outputs<SIGNATURE>`.
|
||||
///
|
||||
/// @see Inputs
|
||||
/// @see Outputs
|
||||
template <typename T, auto SIGNATURE>
|
||||
concept TensorReflectable = requires(const Args<SIGNATURE>& args) {
|
||||
{
|
||||
T::reflect(args,
|
||||
[]([[maybe_unused]] std::string_view name,
|
||||
// Note: This will be a TensorDescriptor<DT, RANK>, but the actual
|
||||
// DT and RANK may differ depending on member.
|
||||
[[maybe_unused]] const auto& desc,
|
||||
[[maybe_unused]] void* T::*ptr) {})
|
||||
};
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// The default alignment between tensors allocated separately
|
||||
/// by `UniqueTensors`. This should be large enough to accomodate
|
||||
/// any type. hipMalloc returns an alignment of 256 by default.
|
||||
constexpr size_t TENSOR_ALIGNMENT = 256;
|
||||
|
||||
/// @brief Common type for automatically managing memory of sets of tensors.
|
||||
///
|
||||
/// This type implements the automatic memory management logic for `Inputs` and
|
||||
/// `Outputs` that support reflection.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature to specialize the structure for.
|
||||
/// @tparam Tensors The `Inputs` or `Outputs` type corresponding to `SIGNATURE`.
|
||||
template <auto SIGNATURE, typename Tensors>
|
||||
requires TensorReflectable<Tensors, SIGNATURE>
|
||||
struct UniqueTensors
|
||||
{
|
||||
/// @brief Allocate tensors.
|
||||
///
|
||||
/// This function computes the total size of memory to allocate according to
|
||||
/// the tensors in `args`, and then allocates it as a continuous buffer.
|
||||
///
|
||||
/// @param args The run-time arguments of the operation.
|
||||
explicit UniqueTensors(const Args<SIGNATURE>& args)
|
||||
{
|
||||
// First compute the total size of all tensors combined
|
||||
size_t total_size = 0;
|
||||
Tensors::reflect(args,
|
||||
[&, this]([[maybe_unused]] std::string_view name,
|
||||
const auto& desc,
|
||||
[[maybe_unused]] void* Tensors::*ptr) {
|
||||
total_size = align_fwd(total_size, TENSOR_ALIGNMENT);
|
||||
total_size += desc.get_element_space_size_in_bytes();
|
||||
});
|
||||
|
||||
data_ = alloc_buffer(total_size);
|
||||
|
||||
// Now assign the pointers based on the same offsets that
|
||||
// we computed in the first loop.
|
||||
size_t offset = 0;
|
||||
Tensors::reflect(args,
|
||||
[&, this]([[maybe_unused]] std::string_view name,
|
||||
const auto& desc,
|
||||
[[maybe_unused]] void* Tensors::*ptr) {
|
||||
offset = align_fwd(offset, TENSOR_ALIGNMENT);
|
||||
tensors_.*ptr = data_.get() + offset;
|
||||
offset += desc.get_element_space_size_in_bytes();
|
||||
});
|
||||
}
|
||||
|
||||
/// @brief Return raw `Inputs` or `Outputs` type.
|
||||
///
|
||||
/// @see ValidUniqueInputs
|
||||
/// @see ValidUniqueOutputs
|
||||
Tensors get() const { return tensors_; }
|
||||
|
||||
private:
|
||||
/// Owning pointer of input memory
|
||||
DeviceBuffer data_;
|
||||
/// Struct with pointers to each tensor. Stored here so that we
|
||||
/// don't need to keep recomputing it.
|
||||
Tensors tensors_;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// @brief Implementation of `UniqueInputs` for `Inputs` that support reflection.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature to specialize for.
|
||||
///
|
||||
/// @see UniqueInputs
|
||||
template <auto SIGNATURE>
|
||||
requires TensorReflectable<Inputs<SIGNATURE>, SIGNATURE>
|
||||
struct UniqueInputs<SIGNATURE> : detail::UniqueTensors<SIGNATURE, Inputs<SIGNATURE>>
|
||||
{
|
||||
using detail::UniqueTensors<SIGNATURE, Inputs<SIGNATURE>>::UniqueTensors;
|
||||
};
|
||||
|
||||
/// @brief Implementation of `UniqueOutputs` for `Outputs` that support reflection.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature to specialize for.
|
||||
///
|
||||
/// @see UniqueOutputs
|
||||
template <auto SIGNATURE>
|
||||
requires TensorReflectable<Outputs<SIGNATURE>, SIGNATURE>
|
||||
struct UniqueOutputs<SIGNATURE> : detail::UniqueTensors<SIGNATURE, Outputs<SIGNATURE>>
|
||||
{
|
||||
using detail::UniqueTensors<SIGNATURE, Outputs<SIGNATURE>>::UniqueTensors;
|
||||
};
|
||||
|
||||
/// @brief Implementation of `alloc_inputs` for `Inputs` that support reflection.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature to specialize for.
|
||||
///
|
||||
/// @param args The run-time arguments of the operation.
|
||||
///
|
||||
/// @see alloc_inputs
|
||||
template <auto SIGNATURE>
|
||||
requires TensorReflectable<Inputs<SIGNATURE>, SIGNATURE>
|
||||
UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args)
|
||||
{
|
||||
static_assert(ValidUniqueInputs<SIGNATURE>, "sanity check");
|
||||
return UniqueInputs<SIGNATURE>(args);
|
||||
}
|
||||
|
||||
/// @brief Implementation of `alloc_outputs` for `Outputs` that support reflection.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature to specialize for.
|
||||
///
|
||||
/// @param args The run-time arguments of the operation.
|
||||
///
|
||||
/// @see alloc_outputs
|
||||
template <auto SIGNATURE>
|
||||
requires TensorReflectable<Outputs<SIGNATURE>, SIGNATURE>
|
||||
UniqueOutputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args)
|
||||
{
|
||||
static_assert(ValidUniqueOutputs<SIGNATURE>, "sanity check");
|
||||
return UniqueOutputs<SIGNATURE>(args);
|
||||
}
|
||||
|
||||
/// @brief Implementation of `validate` for `Outputs` that support reflection.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature to specialize for.
|
||||
///
|
||||
/// @param args The run-time arguments of the operation.
|
||||
/// @param actual The actual results, the results of the operation to-be-tested.
|
||||
/// @param expected The expected results, the results of the reference implementation.
|
||||
///
|
||||
/// @see alloc_outputs
|
||||
template <auto SIGNATURE>
|
||||
requires TensorReflectable<Outputs<SIGNATURE>, SIGNATURE>
|
||||
ValidationReport
|
||||
validate(const Args<SIGNATURE>& args, Outputs<SIGNATURE> actual, Outputs<SIGNATURE> expected)
|
||||
{
|
||||
ValidationReport report;
|
||||
|
||||
Outputs<SIGNATURE>::reflect(
|
||||
args, [&](std::string_view name, const auto& desc, void* Outputs<SIGNATURE>::*ptr) {
|
||||
report.check(name, desc, actual.*ptr, expected.*ptr);
|
||||
});
|
||||
|
||||
return report;
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -39,7 +39,7 @@ constexpr size_t data_type_sizeof(DataType data_type)
|
||||
case DataType::FP8: return 1;
|
||||
case DataType::BF8: return 1;
|
||||
case DataType::FP64: return 8;
|
||||
case DataType::INT32: return 4;
|
||||
case DataType::I32: return 4;
|
||||
case DataType::I8: return 1;
|
||||
case DataType::I8_I8: return 2;
|
||||
case DataType::U8: return 1;
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include "ck_tile/builder/testing/tensor_buffer.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_foreach.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
@@ -24,7 +24,7 @@ enum class DataType
|
||||
FP8,
|
||||
BF8,
|
||||
FP64,
|
||||
INT32,
|
||||
I32,
|
||||
I8,
|
||||
I8_I8,
|
||||
U8
|
||||
@@ -192,8 +192,8 @@ enum class TileConvSpecialization
|
||||
FILTER_3x3
|
||||
};
|
||||
|
||||
// Enums for the forward convolution specialization.
|
||||
enum class ConvFwdSpecialization
|
||||
// Enums for the convolution specializations.
|
||||
enum class ConvSpecialization
|
||||
{
|
||||
DEFAULT,
|
||||
FILTER_1X1_PAD0,
|
||||
@@ -202,22 +202,6 @@ enum class ConvFwdSpecialization
|
||||
ODD_C
|
||||
};
|
||||
|
||||
// Enums for the backward data convolution specialization.
|
||||
enum class ConvBwdDataSpecialization
|
||||
{
|
||||
DEFAULT,
|
||||
FILTER_1X1_STRIDE1_PAD0,
|
||||
};
|
||||
|
||||
// Enums for the backward weight convolution specialization.
|
||||
enum class ConvBwdWeightSpecialization
|
||||
{
|
||||
DEFAULT,
|
||||
FILTER_1X1_STRIDE1_PAD0,
|
||||
FILTER_1X1_PAD0,
|
||||
ODD_C,
|
||||
};
|
||||
|
||||
// Enums for the Gemm padding.
|
||||
enum class GemmPadding
|
||||
{
|
||||
@@ -249,11 +233,13 @@ enum class PipelineScheduler
|
||||
enum class ConvAlgorithmSpecialization
|
||||
{
|
||||
LARGE_TENSOR,
|
||||
REFERENCE // GPU reference implementation for validation
|
||||
REFERENCE, // GPU reference implementation for validation,
|
||||
TWO_STAGE,
|
||||
MULTIPLE_D
|
||||
};
|
||||
|
||||
// toString methods for enum classes
|
||||
inline std::string_view toString(DataType dt)
|
||||
// to_string methods for enum classes
|
||||
inline std::string_view to_string(DataType dt)
|
||||
{
|
||||
using enum DataType;
|
||||
switch(dt)
|
||||
@@ -267,7 +253,7 @@ inline std::string_view toString(DataType dt)
|
||||
case FP8: return "FP8";
|
||||
case BF8: return "BF8";
|
||||
case FP64: return "FP64";
|
||||
case INT32: return "INT32";
|
||||
case I32: return "I32";
|
||||
case I8: return "I8";
|
||||
case I8_I8: return "I8_I8";
|
||||
case U8: return "U8";
|
||||
@@ -276,7 +262,7 @@ inline std::string_view toString(DataType dt)
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string_view toString(ConvDirection dir)
|
||||
inline std::string_view to_string(ConvDirection dir)
|
||||
{
|
||||
using enum ConvDirection;
|
||||
switch(dir)
|
||||
@@ -288,7 +274,7 @@ inline std::string_view toString(ConvDirection dir)
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string_view toString(ElementwiseOperation op)
|
||||
inline std::string_view to_string(ElementwiseOperation op)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
switch(op)
|
||||
@@ -332,7 +318,7 @@ inline std::string_view toString(ElementwiseOperation op)
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string_view toString(PipelineVersion ver)
|
||||
inline std::string_view to_string(PipelineVersion ver)
|
||||
{
|
||||
using enum PipelineVersion;
|
||||
switch(ver)
|
||||
@@ -347,7 +333,7 @@ inline std::string_view toString(PipelineVersion ver)
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string_view toString(GemmSpecialization spec)
|
||||
inline std::string_view to_string(GemmSpecialization spec)
|
||||
{
|
||||
using enum GemmSpecialization;
|
||||
switch(spec)
|
||||
@@ -372,9 +358,9 @@ inline std::string_view toString(GemmSpecialization spec)
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string_view toString(ConvFwdSpecialization spec)
|
||||
inline std::string_view to_string(ConvSpecialization spec)
|
||||
{
|
||||
using enum ConvFwdSpecialization;
|
||||
using enum ConvSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return "DEFAULT";
|
||||
@@ -386,31 +372,7 @@ inline std::string_view toString(ConvFwdSpecialization spec)
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string_view toString(ConvBwdDataSpecialization spec)
|
||||
{
|
||||
using enum ConvBwdDataSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string_view toString(ConvBwdWeightSpecialization spec)
|
||||
{
|
||||
using enum ConvBwdWeightSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
|
||||
case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0";
|
||||
case ODD_C: return "ODD_C";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string_view toString(GemmPadding padding)
|
||||
inline std::string_view to_string(GemmPadding padding)
|
||||
{
|
||||
using enum GemmPadding;
|
||||
switch(padding)
|
||||
@@ -435,7 +397,7 @@ inline std::string_view toString(GemmPadding padding)
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string_view toString(PipelineScheduler sched)
|
||||
inline std::string_view to_string(PipelineScheduler sched)
|
||||
{
|
||||
using enum PipelineScheduler;
|
||||
switch(sched)
|
||||
@@ -447,7 +409,7 @@ inline std::string_view toString(PipelineScheduler sched)
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string_view toString(TensorLayout layout)
|
||||
inline std::string_view to_string(TensorLayout layout)
|
||||
{
|
||||
using enum TensorLayout;
|
||||
switch(layout)
|
||||
@@ -503,63 +465,46 @@ inline std::string_view toString(TensorLayout layout)
|
||||
}
|
||||
|
||||
// ostream operator overloads for enum classes
|
||||
inline std::ostream& operator<<(std::ostream& os, DataType dt) { return os << toString(dt); }
|
||||
inline std::ostream& operator<<(std::ostream& os, DataType dt) { return os << to_string(dt); }
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir) { return os << toString(dir); }
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir)
|
||||
{
|
||||
return os << to_string(dir);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
|
||||
{
|
||||
return os << toString(op);
|
||||
return os << to_string(op);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver)
|
||||
{
|
||||
return os << toString(ver);
|
||||
return os << to_string(ver);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
|
||||
{
|
||||
return os << toString(spec);
|
||||
return os << to_string(spec);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvSpecialization spec)
|
||||
{
|
||||
return os << toString(spec);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
|
||||
{
|
||||
return os << toString(spec);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
|
||||
{
|
||||
return os << toString(spec);
|
||||
return os << to_string(spec);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GemmPadding padding)
|
||||
{
|
||||
return os << toString(padding);
|
||||
return os << to_string(padding);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
|
||||
{
|
||||
return os << toString(sched);
|
||||
return os << to_string(sched);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, TensorLayout layout)
|
||||
{
|
||||
return os << toString(layout);
|
||||
}
|
||||
|
||||
// ostream operator overload for std::variant of convolution specializations
|
||||
inline std::ostream& operator<<(std::ostream& os,
|
||||
const std::variant<ConvFwdSpecialization,
|
||||
ConvBwdDataSpecialization,
|
||||
ConvBwdWeightSpecialization>& spec)
|
||||
{
|
||||
std::visit([&os](const auto& s) { os << s; }, spec);
|
||||
return os;
|
||||
return os << to_string(layout);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -83,11 +83,14 @@ add_ck_builder_test(test_ckb_conv_builder
|
||||
unit_tensor_foreach.cpp
|
||||
unit_error.cpp
|
||||
unit_validation.cpp
|
||||
unit_debug.cpp
|
||||
unit_conv_fwd_testing.cpp
|
||||
unit_conv_elementwise_op.cpp
|
||||
unit_conv_tensor_layout.cpp
|
||||
unit_conv_tensor_type.cpp
|
||||
unit_conv_thread_block.cpp
|
||||
unit_conv_tuning_params.cpp)
|
||||
target_link_libraries(test_ckb_conv_builder PRIVATE utility)
|
||||
|
||||
# Tests the inline diff utility used for comparing strings in tests assertions
|
||||
add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
|
||||
@@ -121,7 +124,7 @@ add_ck_builder_test(test_ckb_conv_description
|
||||
# Verifies that GetInstanceString() methods and other functions produce valid kernel code.
|
||||
# Tests various convolution types:
|
||||
# - Group convolution (v3, standard, large tensor, WMMA, DL variants)
|
||||
# - Backward weight group convolution (XDL)
|
||||
# - Backward weight group convolution (XDL standard, XDL v3, WMMA, DL, multiple D, two-stage variants)
|
||||
# Requires kernel compilation to validate the generated strings through the base class.
|
||||
|
||||
set(INSTANCE_STRING_TESTS
|
||||
@@ -164,10 +167,35 @@ add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/ck/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
|
||||
conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp
|
||||
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
|
||||
conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp)
|
||||
)
|
||||
target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
|
||||
|
||||
set(BWD_WEIGHT_TESTS
|
||||
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_dl.cpp
|
||||
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
|
||||
)
|
||||
|
||||
if (CK_USE_WMMA)
|
||||
list(APPEND BWD_WEIGHT_TESTS
|
||||
conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp
|
||||
conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
add_ck_builder_test(test_ckb_build_bwd_weight_instances ${BWD_WEIGHT_TESTS})
|
||||
target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility)
|
||||
|
||||
add_ck_builder_test(test_ckb_build_bwd_data_instances
|
||||
conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp
|
||||
)
|
||||
target_link_libraries(test_ckb_build_bwd_data_instances PRIVATE utility)
|
||||
|
||||
|
||||
################################################################################
|
||||
# FACTORY TESTS - Expensive Regression Tests (Full MIOpen Kernel Set)
|
||||
@@ -221,6 +249,8 @@ endforeach()
|
||||
set(CKB_REGRESSION_TESTS
|
||||
test_ckb_instance_string
|
||||
test_ckb_build_fwd_instances
|
||||
test_ckb_build_bwd_weight_instances
|
||||
test_ckb_build_bwd_data_instances
|
||||
test_ckb_testing_utils
|
||||
# test_ckb_factory_grouped_convolution_forward_convscale
|
||||
# test_ckb_factory_grouped_convolution_forward_scaleadd_ab
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::BF16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl{}
|
||||
.with_thread_block(cku::ThreadBlock_256_128x128x16)
|
||||
.with_bwd_specialization(cku::ConvSpecialization::DEFAULT)
|
||||
.with_dl_thread_config(cku::DlThreadConfig_16x1x4x4x1)
|
||||
.with_dl_thread_cluster(cku::DlThreadCluster_8x2)
|
||||
.with_dl_transfer(cku::DlTransfer5D);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_2DBf16_DL, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Dl",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNHWC,GKYXC,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough"});
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 3,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNDHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKZYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNDHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_3DFp16_MultiD_Wmma_ShuffleV3_GNHWC, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNDHWC,GKZYXC,GNDHWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"fp16,fp16>"}); // check compute types
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle{}
|
||||
.with_thread_block(cku::ThreadBlock_256_128x128x8)
|
||||
.with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x64x1)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_2DFp16_MultiD_CShuffle_GNHWC, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNHWC,GKYXC,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"fp16,fp16>"}); // check compute types
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = NGCHW}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = NGKHW}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave)
|
||||
.with_num_conv_groups_to_merge(2)
|
||||
.with_transpose_params(2, 2);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_2DFp16_TwoStage_Wmma_CShuffle_V3, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"NGCHW,GKYXC,NGKHW",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"Intrawave",
|
||||
"v1",
|
||||
"fp16,fp16,2,2>"}); // Check compute types and transpose params.
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::BF16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave)
|
||||
.with_num_conv_groups_to_merge(2)
|
||||
.with_transpose_params(2, 4);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_2DBf16_TwoStage_CShuffle, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNHWC,GKYXC,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"Intrawave,v2", // pipeline versions
|
||||
"bf16,bf16,2,4>"}); // compute types and transpose params
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 3,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::BF16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = NGCDHW}},
|
||||
.weight = {.config = {.layout = GKZYXC}},
|
||||
.output = {.config = {.layout = NGKDHW}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT)
|
||||
.with_gridwise_gemm_pipeline(ckb::PipelineVersion::V1);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_3DBf16_Wmma_CShuffle, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Wmma_CShuffle",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"NGCDHW,GKZYXC,NGKDHW",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"v1"});
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::BF16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = NGCW}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NGKW}}};
|
||||
|
||||
constexpr auto ALGORITHM =
|
||||
cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v1_intrawave)
|
||||
.with_transpose_params(4, 4);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_1DBf16_Wmma_CShuffle_V3, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Wmma_CShuffleV3",
|
||||
expected_transfer_parameters,
|
||||
"Filter1x1Stride1Pad0",
|
||||
"NGCW,GKXC,NGKW",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"Intrawave",
|
||||
"v1",
|
||||
"bf16,bf16,4,4>"});
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = ckb::TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle{}
|
||||
.with_thread_block(cku::ThreadBlock_256_128x128x8)
|
||||
.with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x64x1)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::DEFAULT)
|
||||
.with_transpose_params(2, 2);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_2DFp16_CShuffle_GNHWC, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffle",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNHWC,GKYXC,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"fp16,fp16,2,2>"}); // check compute types and transpose params
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::BF16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = NGCW}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NGKW}}};
|
||||
|
||||
constexpr auto ALGORITHM =
|
||||
cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(cku::ThreadBlock_64_32x32x32)
|
||||
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
|
||||
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
|
||||
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
TEST(BwdWeight_1DBf16_CShuffle_V3, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3",
|
||||
expected_transfer_parameters,
|
||||
"Filter1x1Stride1Pad0",
|
||||
"NGCW,GKXC,NGKW",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"Intrawave",
|
||||
"v2"});
|
||||
}
|
||||
@@ -30,11 +30,11 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v2_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -27,11 +27,12 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_64_64x32x32)
|
||||
.with_thread_block(ThreadBlock_64_64x32x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 2, PipelineScheduler::DEFAULT);
|
||||
.with_transfer(Transfer_4x16x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
|
||||
.with_num_conv_groups_to_merge(2);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
@@ -22,18 +22,20 @@ TEST(FwdConvInstances,
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
|
||||
.direction = FORWARD,
|
||||
.data_type = I8,
|
||||
.accumulation_data_type = INT32,
|
||||
.accumulation_data_type = I32,
|
||||
.input = {.config = {.layout = GNWC}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = GNWK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_128_64x64x64)
|
||||
.with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x32x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 0, PipelineScheduler::DEFAULT);
|
||||
.with_thread_block(ThreadBlock_128_64x64x64)
|
||||
.with_gemm_config(GemmParams_Wmma_2x1_per_wave)
|
||||
.with_transfer(Transfer_4x32x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
|
||||
.with_num_conv_groups_to_merge(2)
|
||||
.with_gridwise_gemm_pipeline(PipelineVersion::V1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
@@ -27,10 +27,10 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
@@ -64,10 +64,11 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_3x3,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v5_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -32,11 +32,12 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_64_64x32x32)
|
||||
.with_thread_block(ThreadBlock_64_64x32x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
|
||||
.with_transfer(Transfer_4x16x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
@@ -25,15 +25,16 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x16)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_thread_block(ThreadBlock_256_128x128x16)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
|
||||
.with_dl_thread_cluster(DlThreadCluster_8x2)
|
||||
.with_dl_transfer(DlFwdTransfer);
|
||||
.with_dl_transfer(DlTransfer4D);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
@@ -59,16 +60,17 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x16)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_thread_block(ThreadBlock_256_128x128x16)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
|
||||
.with_dl_thread_cluster(DlThreadCluster_8x2)
|
||||
.with_dl_transfer(DlFwdTransfer);
|
||||
.with_dl_transfer(DlTransfer4D);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl;
|
||||
run_test<Builder>({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK",
|
||||
expected_transfer_parameters,
|
||||
"Filter1x1Pad0",
|
||||
|
||||
@@ -25,11 +25,11 @@ constexpr auto SIGNATURE =
|
||||
.output = {.config = {.layout = ckb::TensorLayout::GNHWK}}};
|
||||
|
||||
constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(cku::FwdThreadBlock_256_256x256x32)
|
||||
.with_thread_block(cku::ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(cku::FwdTransfer_4x64x1)
|
||||
.with_specializations(ckb::ConvFwdSpecialization::DEFAULT,
|
||||
ckb::GemmSpecialization::MNKPadding)
|
||||
.with_transfer(cku::Transfer_4x64x1)
|
||||
.with_fwd_specializations(ckb::ConvSpecialization::DEFAULT,
|
||||
ckb::GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(cku::BlockGemmDesc_v3_intrawave);
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
|
||||
@@ -26,11 +26,11 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x32)
|
||||
.with_thread_block(ThreadBlock_256_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v4_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -27,11 +27,12 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x128x32)
|
||||
.with_thread_block(ThreadBlock_256_256x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1_fp8)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
|
||||
.with_transfer(Transfer_4x64x1_fp8)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
@@ -25,14 +25,13 @@ TEST(FwdConvInstances,
|
||||
.output = {.config = {.layout = GNHWK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
|
||||
.base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{}
|
||||
.with_thread_block(ThreadBlock_256_256x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(Transfer_4x16x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -62,14 +61,14 @@ TEST(
|
||||
.output = {.config = {.layout = GNHWK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
|
||||
.base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
.with_thread_block(FwdThreadBlock_128_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x16x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{}
|
||||
.with_thread_block(ThreadBlock_128_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(Transfer_4x16x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
|
||||
.with_num_conv_groups_to_merge(1);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
|
||||
@@ -27,10 +27,10 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v3_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -27,11 +27,11 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_128x128x32)
|
||||
.with_thread_block(ThreadBlock_256_128x128x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v4_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -27,11 +27,11 @@ TEST(FwdConvInstances,
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
.with_thread_block(FwdThreadBlock_256_256x256x32)
|
||||
.with_thread_block(ThreadBlock_256_256x256x32)
|
||||
.with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave)
|
||||
.with_transfer(FwdTransfer_4x64x1)
|
||||
.with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_transfer(Transfer_4x64x1)
|
||||
.with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0,
|
||||
GemmSpecialization::MNKPadding)
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
@@ -101,7 +101,7 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(Traits::thread_block_size, 256);
|
||||
@@ -229,7 +229,7 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(Traits::thread_block_size, 256);
|
||||
@@ -313,7 +313,7 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
|
||||
// Verify algorithm information
|
||||
EXPECT_EQ(Traits::thread_block_size, 256);
|
||||
|
||||
@@ -230,7 +230,7 @@ TEST(InstanceToConvTraits, ExtractsDefaultSpecialization)
|
||||
|
||||
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
|
||||
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
|
||||
}
|
||||
|
||||
TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization)
|
||||
@@ -289,8 +289,7 @@ TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization)
|
||||
|
||||
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
|
||||
|
||||
EXPECT_EQ(Traits::conv_specialization,
|
||||
ck_tile::builder::ConvFwdSpecialization::FILTER_1X1_PAD0);
|
||||
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
||||
@@ -8,26 +8,27 @@ namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_DATA,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
constexpr ConvSignature BwdDataConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_DATA,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
constexpr auto BwdDataConvAlgorithm =
|
||||
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
|
||||
.with_tile_thread_block(TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(FwdTileTransfer_4x4x4)
|
||||
.with_tile_transfer(TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
using Builder = ConvBuilder<BwdDataConvSignature, BwdDataConvAlgorithm>;
|
||||
run_ck_tile_test<Builder>({
|
||||
"grouped_convolution_backward_data",
|
||||
"fp16",
|
||||
|
||||
@@ -8,26 +8,27 @@ namespace {
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
constexpr ConvSignature BwdWeightConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
constexpr auto BwdWeightConvAlgorithm =
|
||||
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
|
||||
.with_tile_thread_block(TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(FwdTileTransfer_4x4x4)
|
||||
.with_tile_transfer(TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
using Builder = ConvBuilder<BwdWeightConvSignature, BwdWeightConvAlgorithm>;
|
||||
run_ck_tile_test<Builder>({
|
||||
"grouped_convolution_backward_weight",
|
||||
"fp16",
|
||||
|
||||
@@ -21,9 +21,9 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(FwdTileThreadBlock_64x64x64)
|
||||
.with_tile_thread_block(TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(FwdTileTransfer_4x4x4)
|
||||
.with_tile_transfer(TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
|
||||
@@ -28,18 +28,31 @@ struct ThreadBlock
|
||||
};
|
||||
static_assert(ckb::ThreadBlockDescriptor<ThreadBlock>);
|
||||
|
||||
// Describe gridwise XDL GEMM parameters.
|
||||
struct GridwiseXdlGemm
|
||||
struct XdlParams
|
||||
{
|
||||
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
|
||||
size_t ak1 = 0;
|
||||
size_t bk1 = 0;
|
||||
size_t m_per_xdl = 0;
|
||||
size_t n_per_xdl = 0;
|
||||
size_t m_xdl_per_wave = 0;
|
||||
size_t n_xdl_per_wave = 0;
|
||||
};
|
||||
static_assert(ckb::GridwiseXdlGemmDescriptor<GridwiseXdlGemm>);
|
||||
static_assert(ckb::GridwiseXdlGemmDescriptor<XdlParams>);
|
||||
|
||||
// Describe gridwise XDL GEMM parameters.
|
||||
struct GridwiseFwdXdlGemm
|
||||
{
|
||||
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
|
||||
size_t ak1 = 0;
|
||||
size_t bk1 = 0;
|
||||
XdlParams xdl_params;
|
||||
};
|
||||
static_assert(ckb::GridwiseFwdXdlGemmDescriptor<GridwiseFwdXdlGemm>);
|
||||
|
||||
struct GridwiseBwdXdlGemm
|
||||
{
|
||||
size_t k1 = 0;
|
||||
XdlParams xdl_params;
|
||||
};
|
||||
static_assert(ckb::GridwiseBwdXdlGemmDescriptor<GridwiseBwdXdlGemm>);
|
||||
|
||||
// Describe gridwise WMMA GEMM parameters.
|
||||
struct GridwiseWmmaGemm
|
||||
@@ -49,25 +62,36 @@ struct GridwiseWmmaGemm
|
||||
size_t n_per_wmma = 0;
|
||||
size_t m_wmma_per_wave = 0;
|
||||
size_t n_wmma_per_wave = 0;
|
||||
PipelineVersion pipeline_version;
|
||||
};
|
||||
static_assert(ckb::GridwiseWmmaGemmDescriptor<GridwiseWmmaGemm>);
|
||||
|
||||
struct BlockGemm
|
||||
struct BlockGemmPipeline
|
||||
{
|
||||
PipelineVersion pipeline_version;
|
||||
PipelineScheduler scheduler;
|
||||
};
|
||||
static_assert(ckb::BlockGemmDescriptor<BlockGemm>);
|
||||
static_assert(ckb::BlockGemmPipelineDescriptor<BlockGemmPipeline>);
|
||||
|
||||
// Describe Aand B block transfer thread cluster lengths.
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct BlockTransfer
|
||||
{
|
||||
size_t k0;
|
||||
size_t m_n;
|
||||
size_t k1;
|
||||
size_t k_batch_size;
|
||||
};
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer>);
|
||||
|
||||
// Specialization for ThreadSliceLength == 3
|
||||
template <>
|
||||
struct BlockTransfer<3>
|
||||
{
|
||||
size_t k0;
|
||||
size_t m_n;
|
||||
size_t k1;
|
||||
};
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer<3>, 3>);
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer<4>, 4>);
|
||||
|
||||
// Describe C block transfer thread cluster lengths.
|
||||
struct ThreadCluster
|
||||
@@ -97,31 +121,35 @@ struct Epilogue
|
||||
};
|
||||
static_assert(EpilogueDescriptor<Epilogue>);
|
||||
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct AccessOrder
|
||||
{
|
||||
std::array<size_t, 3> order;
|
||||
std::array<size_t, ThreadSliceLength> order;
|
||||
};
|
||||
static_assert(AccessOrderDescriptor<AccessOrder>);
|
||||
static_assert(AccessOrderDescriptor<AccessOrder<>>);
|
||||
static_assert(AccessOrderDescriptor<AccessOrder<4>>);
|
||||
|
||||
struct TransferAB
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct InputTransfer
|
||||
{
|
||||
BlockTransfer block_transfer;
|
||||
BlockTransfer<ThreadSliceLength> block_transfer;
|
||||
LdsTransfer lds_transfer;
|
||||
AccessOrder block_transfer_access_order;
|
||||
AccessOrder src_access_order;
|
||||
AccessOrder<ThreadSliceLength> block_transfer_access_order;
|
||||
AccessOrder<ThreadSliceLength> src_access_order;
|
||||
};
|
||||
|
||||
struct TransferC
|
||||
struct OutputTransfer
|
||||
{
|
||||
ThreadCluster thread_cluster_dims;
|
||||
Epilogue epilogue;
|
||||
};
|
||||
|
||||
struct TransferABC
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct Transfer
|
||||
{
|
||||
TransferAB a;
|
||||
TransferAB b;
|
||||
TransferC c;
|
||||
InputTransfer<ThreadSliceLength> a;
|
||||
InputTransfer<ThreadSliceLength> b;
|
||||
OutputTransfer c;
|
||||
};
|
||||
|
||||
// DL-specific descriptors
|
||||
@@ -142,17 +170,19 @@ struct DlThreadCluster
|
||||
};
|
||||
static_assert(ckb::DlThreadClusterDescriptor<DlThreadCluster>);
|
||||
|
||||
template <size_t D = 4>
|
||||
struct DlBlockTransfer
|
||||
{
|
||||
std::array<size_t, 4> thread_slice_lengths;
|
||||
std::array<size_t, 4> thread_cluster_lengths;
|
||||
std::array<size_t, 4> thread_cluster_arrange_order;
|
||||
std::array<size_t, 4> src_access_order;
|
||||
std::array<size_t, 4> src_vector_tensor_lengths;
|
||||
std::array<size_t, 4> src_vector_tensor_contiguous_dim_order;
|
||||
std::array<size_t, 4> dst_vector_tensor_lengths;
|
||||
std::array<size_t, D> thread_slice_lengths;
|
||||
std::array<size_t, D> thread_cluster_lengths;
|
||||
std::array<size_t, D> thread_cluster_arrange_order;
|
||||
std::array<size_t, D> src_access_order;
|
||||
std::array<size_t, D> src_vector_tensor_lengths;
|
||||
std::array<size_t, D> src_vector_tensor_contiguous_dim_order;
|
||||
std::array<size_t, D> dst_vector_tensor_lengths;
|
||||
};
|
||||
static_assert(ckb::DlBlockTransferDescriptor<DlBlockTransfer>);
|
||||
static_assert(ckb::DlBlockTransferDescriptor4D<DlBlockTransfer<4>>);
|
||||
static_assert(ckb::DlBlockTransferDescriptor5D<DlBlockTransfer<5>>);
|
||||
|
||||
struct DlEpilogue
|
||||
{
|
||||
@@ -169,9 +199,14 @@ struct ThreadBlock_
|
||||
ThreadBlock thread_block;
|
||||
};
|
||||
|
||||
struct XdlGemm_
|
||||
struct FwdXdlGemm_
|
||||
{
|
||||
GridwiseXdlGemm gridwise_gemm;
|
||||
GridwiseFwdXdlGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
struct BwdXdlGemm_
|
||||
{
|
||||
GridwiseBwdXdlGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
struct WmmaGemm_
|
||||
@@ -179,27 +214,48 @@ struct WmmaGemm_
|
||||
GridwiseWmmaGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
template <size_t ThreadSliceLength = 3>
|
||||
struct Transfer_
|
||||
{
|
||||
TransferABC transfer;
|
||||
Transfer<ThreadSliceLength> transfer;
|
||||
};
|
||||
|
||||
struct ConvSpecialization_
|
||||
struct ConvSpecializationFwd_
|
||||
{
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
ConvSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
};
|
||||
|
||||
struct ConvSpecializationBwdWeight_
|
||||
{
|
||||
ConvSpecialization bwd_weight_specialization;
|
||||
};
|
||||
|
||||
struct Prefetch_
|
||||
{
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
size_t num_groups_to_merge;
|
||||
PipelineScheduler loop_scheduler;
|
||||
};
|
||||
|
||||
struct TransposeParams_
|
||||
{
|
||||
size_t max_transpose_transfer_src_scalar_per_vector{1};
|
||||
size_t max_transpose_transfer_dst_scalar_per_vector{1};
|
||||
};
|
||||
|
||||
struct GemmBatchOptions_
|
||||
{
|
||||
size_t num_conv_groups_to_merge{1};
|
||||
};
|
||||
|
||||
struct BlockGemm_
|
||||
{
|
||||
BlockGemm block_gemm;
|
||||
BlockGemmPipeline block_gemm_pipeline;
|
||||
};
|
||||
|
||||
struct GridGemm_
|
||||
{
|
||||
PipelineVersion pipeline_version;
|
||||
};
|
||||
|
||||
struct DlThreadConfig_
|
||||
@@ -212,33 +268,34 @@ struct DlThreadCluster_
|
||||
DlThreadCluster thread_cluster;
|
||||
};
|
||||
|
||||
struct DlBlockTransferAB
|
||||
template <size_t Dim = 4>
|
||||
struct DlTransfer
|
||||
{
|
||||
DlBlockTransfer block_transfer;
|
||||
};
|
||||
|
||||
struct DlBlockTransferC
|
||||
{
|
||||
DlEpilogue epilogue;
|
||||
};
|
||||
|
||||
struct DlTransferABC
|
||||
{
|
||||
DlBlockTransferAB a;
|
||||
DlBlockTransferAB b;
|
||||
DlBlockTransferC c;
|
||||
DlBlockTransfer<Dim> a;
|
||||
DlBlockTransfer<Dim> b;
|
||||
DlEpilogue c;
|
||||
};
|
||||
|
||||
template <size_t Dim = 4>
|
||||
struct DlTransfer_
|
||||
{
|
||||
DlTransferABC transfer;
|
||||
DlTransfer<Dim> transfer;
|
||||
};
|
||||
|
||||
// Specialization wrapper for large tensor support
|
||||
template <typename BaseAlgorithm>
|
||||
struct LargeTensorWrapper
|
||||
struct TwoStageSpecialization_
|
||||
{
|
||||
static constexpr ConvAlgorithmSpecialization specialization =
|
||||
ConvAlgorithmSpecialization::TWO_STAGE;
|
||||
};
|
||||
|
||||
struct MultipleDSpecialization_
|
||||
{
|
||||
static constexpr ConvAlgorithmSpecialization specialization =
|
||||
ConvAlgorithmSpecialization::MULTIPLE_D;
|
||||
};
|
||||
|
||||
struct LargeTensorSpecialization_
|
||||
{
|
||||
BaseAlgorithm base_algorithm;
|
||||
static constexpr ConvAlgorithmSpecialization specialization =
|
||||
ConvAlgorithmSpecialization::LARGE_TENSOR;
|
||||
};
|
||||
@@ -329,7 +386,11 @@ struct ConvAlgorithmTemplate : Components...
|
||||
constexpr auto with_gemm_config(const GemmConfig& gemm) const
|
||||
{
|
||||
auto result = *this;
|
||||
if constexpr(std::is_base_of_v<XdlGemm_, ConvAlgorithmTemplate>)
|
||||
if constexpr(std::is_base_of_v<FwdXdlGemm_, ConvAlgorithmTemplate>)
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
else if constexpr(std::is_base_of_v<BwdXdlGemm_, ConvAlgorithmTemplate>)
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
@@ -337,46 +398,82 @@ struct ConvAlgorithmTemplate : Components...
|
||||
{
|
||||
result.gridwise_gemm = gemm;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unrecognized GemmConfig type");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr auto with_transfer(const T& t) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<Transfer_, ConvAlgorithmTemplate>);
|
||||
static_assert(std::is_base_of_v<Transfer_<3>, ConvAlgorithmTemplate> ||
|
||||
std::is_base_of_v<Transfer_<4>, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.transfer = t;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_specializations(ConvFwdSpecialization fwd_spec,
|
||||
GemmSpecialization gemm_spec) const
|
||||
constexpr auto with_fwd_specializations(ConvSpecialization fwd_spec,
|
||||
GemmSpecialization gemm_spec) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<ConvSpecialization_, ConvAlgorithmTemplate>);
|
||||
static_assert(std::is_base_of_v<ConvSpecializationFwd_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.fwd_specialization = fwd_spec;
|
||||
result.gemm_specialization = gemm_spec;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_prefetch_config(size_t k_prefetch_stages,
|
||||
size_t groups_to_merge,
|
||||
PipelineScheduler scheduler) const
|
||||
constexpr auto with_bwd_specialization(ConvSpecialization bwd_spec) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<ConvSpecializationBwdWeight_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.bwd_weight_specialization = bwd_spec;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_prefetch_config(size_t k_prefetch_stages, PipelineScheduler scheduler) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<Prefetch_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.num_gemm_k_prefetch_stages = k_prefetch_stages;
|
||||
result.num_groups_to_merge = groups_to_merge;
|
||||
result.loop_scheduler = scheduler;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_transpose_params(size_t max_src_scalar_per_vector,
|
||||
size_t max_dst_scalar_per_vector) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<TransposeParams_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.max_transpose_transfer_src_scalar_per_vector = max_src_scalar_per_vector;
|
||||
result.max_transpose_transfer_dst_scalar_per_vector = max_dst_scalar_per_vector;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<GemmBatchOptions_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.num_conv_groups_to_merge = num_groups_to_merge;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BG>
|
||||
constexpr auto with_block_gemm(const BG& bg) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<BlockGemm_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.block_gemm = bg;
|
||||
auto result = *this;
|
||||
result.block_gemm_pipeline = bg;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_gridwise_gemm_pipeline(const PipelineVersion plv) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<GridGemm_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.pipeline_version = plv;
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -401,7 +498,8 @@ struct ConvAlgorithmTemplate : Components...
|
||||
template <typename T>
|
||||
constexpr auto with_dl_transfer(const T& t) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<DlTransfer_, ConvAlgorithmTemplate>);
|
||||
static_assert(std::is_base_of_v<DlTransfer_<4>, ConvAlgorithmTemplate> ||
|
||||
std::is_base_of_v<DlTransfer_<5>, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.transfer = t;
|
||||
return result;
|
||||
@@ -453,26 +551,49 @@ struct ConvAlgorithmTemplate : Components...
|
||||
}
|
||||
};
|
||||
|
||||
// Algorithm types
|
||||
// Fwd algorithm types
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, Transfer_, ConvSpecialization_, Prefetch_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
FwdXdlGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationFwd_,
|
||||
Prefetch_,
|
||||
GemmBatchOptions_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, Transfer_, ConvSpecialization_, BlockGemm_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
FwdXdlGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationFwd_,
|
||||
BlockGemm_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_, ConvSpecialization_, Prefetch_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
WmmaGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationFwd_,
|
||||
GridGemm_,
|
||||
Prefetch_,
|
||||
GemmBatchOptions_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
ConvSpecialization_,
|
||||
ConvSpecializationFwd_,
|
||||
DlThreadConfig_,
|
||||
DlThreadCluster_,
|
||||
DlTransfer_>;
|
||||
DlTransfer_<>>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
LargeTensorWrapper<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
FwdXdlGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationFwd_,
|
||||
Prefetch_,
|
||||
GemmBatchOptions_,
|
||||
LargeTensorSpecialization_>;
|
||||
|
||||
// CK Tile algorithm
|
||||
using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileThreadBlock_,
|
||||
TileBlockGemm_,
|
||||
TileTransfer_,
|
||||
@@ -488,4 +609,77 @@ struct ConvAlgorithm_Reference
|
||||
// GPU reference uses simple algorithm, no tile configuration needed
|
||||
};
|
||||
|
||||
// Bwd weight algorithm types
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
BwdXdlGemm_,
|
||||
Transfer_<4>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
TransposeParams_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
BwdXdlGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
BlockGemm_,
|
||||
TransposeParams_,
|
||||
GemmBatchOptions_,
|
||||
TwoStageSpecialization_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
BwdXdlGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
BlockGemm_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
DlThreadConfig_,
|
||||
DlThreadCluster_,
|
||||
DlTransfer_<5>,
|
||||
ConvSpecializationBwdWeight_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
BwdXdlGemm_,
|
||||
Transfer_<4>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
MultipleDSpecialization_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
WmmaGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
BlockGemm_,
|
||||
TransposeParams_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
WmmaGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
BlockGemm_,
|
||||
TransposeParams_,
|
||||
GemmBatchOptions_,
|
||||
TwoStageSpecialization_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
WmmaGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
GridGemm_,
|
||||
Prefetch_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
WmmaGemm_,
|
||||
Transfer_<>,
|
||||
ConvSpecializationBwdWeight_,
|
||||
BlockGemm_,
|
||||
MultipleDSpecialization_>;
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -120,14 +120,12 @@ struct DefaultAlgorithm
|
||||
ckb::test::ThreadBlock thread_block{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 16,
|
||||
.n_per_xdl = 16,
|
||||
.m_xdl_per_wave = 8,
|
||||
.n_xdl_per_wave = 8};
|
||||
ckb::test::GridwiseFwdXdlGemm gridwise_gemm{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 16, .n_per_xdl = 16, .m_xdl_per_wave = 8, .n_xdl_per_wave = 8}};
|
||||
|
||||
ckb::test::TransferABC transfer{
|
||||
ckb::test::Transfer<> transfer{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2},
|
||||
@@ -161,10 +159,11 @@ struct DefaultAlgorithm
|
||||
},
|
||||
};
|
||||
|
||||
ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT;
|
||||
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
|
||||
ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4,
|
||||
.scheduler = ckb::PipelineScheduler::INTRAWAVE};
|
||||
ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT;
|
||||
ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default;
|
||||
ckb::test::BlockGemmPipeline block_gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4,
|
||||
.scheduler =
|
||||
ckb::PipelineScheduler::INTRAWAVE};
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<DefaultAlgorithm>);
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_foreach.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <vector>
|
||||
@@ -12,6 +13,7 @@ namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
using ::testing::Eq;
|
||||
using ::testing::NotNull;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
@@ -57,6 +59,8 @@ using UniqueOutputs = ckt::UniqueOutputs<SIGNATURE>;
|
||||
|
||||
static_assert(ckt::ValidUniqueInputs<SIGNATURE>);
|
||||
static_assert(ckt::ValidUniqueOutputs<SIGNATURE>);
|
||||
static_assert(ckt::TensorReflectable<Inputs, SIGNATURE>);
|
||||
static_assert(ckt::TensorReflectable<Outputs, SIGNATURE>);
|
||||
|
||||
TEST(ConvFwdTesting, MakeDescriptors)
|
||||
{
|
||||
@@ -81,3 +85,41 @@ TEST(ConvFwdTesting, Alloc)
|
||||
EXPECT_THAT(inputs.get().weight, NotNull());
|
||||
EXPECT_THAT(outputs.get().output, NotNull());
|
||||
}
|
||||
|
||||
TEST(ConvFwdTesting, Validate)
|
||||
{
|
||||
auto a = alloc_outputs(ARGS);
|
||||
auto b = alloc_outputs(ARGS);
|
||||
|
||||
// Positive test
|
||||
{
|
||||
ckt::Outputs<SIGNATURE>::reflect(
|
||||
ARGS,
|
||||
[&]([[maybe_unused]] std::string_view name,
|
||||
const auto& desc,
|
||||
void* ckt::Outputs<SIGNATURE>::*ptr) {
|
||||
ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{123});
|
||||
ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{123});
|
||||
});
|
||||
|
||||
const auto report = ckt::validate(ARGS, a.get(), b.get());
|
||||
EXPECT_THAT(report.get_errors().size(), Eq(0));
|
||||
}
|
||||
|
||||
// Negative test
|
||||
{
|
||||
size_t field_count = 0;
|
||||
ckt::Outputs<SIGNATURE>::reflect(
|
||||
ARGS,
|
||||
[&]([[maybe_unused]] std::string_view name,
|
||||
const auto& desc,
|
||||
void* ckt::Outputs<SIGNATURE>::*ptr) {
|
||||
++field_count;
|
||||
ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{2});
|
||||
ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{1});
|
||||
});
|
||||
|
||||
const auto report = ckt::validate(ARGS, a.get(), b.get());
|
||||
EXPECT_THAT(report.get_errors().size(), Eq(field_count));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,11 +38,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK)
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NWGK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -57,11 +57,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW)
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NGKW}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -76,11 +76,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK)
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = GNWK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -95,11 +95,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW)
|
||||
.weight = {.config = {.layout = GKCX}},
|
||||
.output = {.config = {.layout = NGKW}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -114,11 +114,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW)
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = NGKHW}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -133,11 +133,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK)
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = NHWGK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -152,11 +152,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK)
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = GNHWK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNHWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -171,11 +171,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW)
|
||||
.weight = {.config = {.layout = GKCYX}},
|
||||
.output = {.config = {.layout = NGKHW}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCYX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCYX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -190,11 +190,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW)
|
||||
.weight = {.config = {.layout = GKCZYX}},
|
||||
.output = {.config = {.layout = NGKDHW}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCDHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCZYX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKDHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCDHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCZYX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKDHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -209,11 +209,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK)
|
||||
.weight = {.config = {.layout = GKZYXC}},
|
||||
.output = {.config = {.layout = NDHWGK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NDHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NDHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -228,11 +228,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK)
|
||||
.weight = {.config = {.layout = GKZYXC}},
|
||||
.output = {.config = {.layout = GNDHWK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNDHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNDHWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNDHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNDHWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -273,7 +273,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout)
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = G_K_strided}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
|
||||
|
||||
EXPECT_EQ(AuxLayouts::Size, 1);
|
||||
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K>;
|
||||
@@ -287,7 +287,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout)
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
|
||||
|
||||
EXPECT_EQ(AuxLayouts::Size, 1);
|
||||
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::GC>;
|
||||
@@ -301,7 +301,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout)
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = G_C_strided}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
|
||||
|
||||
EXPECT_EQ(AuxLayouts::Size, 1);
|
||||
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_C>;
|
||||
@@ -316,7 +316,7 @@ TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors)
|
||||
MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided},
|
||||
MockAuxiliaryTensorConfig{.layout = GC}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
|
||||
|
||||
EXPECT_EQ(AuxLayouts::Size, 2);
|
||||
using ExpectedType =
|
||||
@@ -333,7 +333,7 @@ TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors)
|
||||
MockAuxiliaryTensorConfig{.layout = GC},
|
||||
MockAuxiliaryTensorConfig{.layout = G_C_strided}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2>;
|
||||
|
||||
EXPECT_EQ(AuxLayouts::Size, 3);
|
||||
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K,
|
||||
@@ -349,7 +349,7 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith1DConvolution)
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = G_K_strided}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 1, FORWARD>;
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 1>;
|
||||
|
||||
EXPECT_EQ(AuxLayouts::Size, 1);
|
||||
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::G_K>;
|
||||
@@ -363,7 +363,7 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution)
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = GC}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 3, FORWARD>;
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 3>;
|
||||
|
||||
EXPECT_EQ(AuxLayouts::Size, 1);
|
||||
using ExpectedType = ck::Tuple<ck::tensor_layout::convolution::GC>;
|
||||
@@ -387,11 +387,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K)
|
||||
.operation =
|
||||
OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
|
||||
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
|
||||
@@ -414,11 +414,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC)
|
||||
.operation =
|
||||
OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NHWGK>));
|
||||
|
||||
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::GC>;
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
|
||||
@@ -442,11 +442,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors)
|
||||
.operation = OutputOp{.elementwise_operation =
|
||||
ElementwiseOperation::SCALEADD_SCALEADD_RELU}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNHWK>));
|
||||
|
||||
using ExpectedDsLayout =
|
||||
ck::Tuple<ck::tensor_layout::convolution::G_K, ck::tensor_layout::convolution::GC>;
|
||||
@@ -470,11 +470,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias)
|
||||
.operation =
|
||||
OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NWGK>));
|
||||
|
||||
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
|
||||
@@ -497,11 +497,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias)
|
||||
.operation = OutputOp{.elementwise_operation =
|
||||
ElementwiseOperation::BIAS_BNORM_CLAMP}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NDHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NDHWGK>));
|
||||
|
||||
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_C>;
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
|
||||
|
||||
@@ -27,7 +27,7 @@ TEST(ConvTensorType, Exhaustive)
|
||||
case FP32: EXPECT_TRUE((check_same<FP32, float>)); break;
|
||||
case FP16: EXPECT_TRUE((check_same<FP16, ck::half_t>)); break;
|
||||
case BF16: EXPECT_TRUE((check_same<BF16, ck::bhalf_t>)); break;
|
||||
case INT32: EXPECT_TRUE((check_same<INT32, uint32_t>)); break;
|
||||
case I32: EXPECT_TRUE((check_same<I32, uint32_t>)); break;
|
||||
case FP8: EXPECT_TRUE((check_same<FP8, ck::f8_t>)); break;
|
||||
case I8: EXPECT_TRUE((check_same<I8, int8_t>)); break;
|
||||
case U8: EXPECT_TRUE((check_same<U8, uint8_t>)); break;
|
||||
|
||||
@@ -19,7 +19,7 @@ TEST(ConvTuningParams, AssignsBlockGemmParams)
|
||||
{
|
||||
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3;
|
||||
ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTRAWAVE;
|
||||
} block_gemm;
|
||||
} block_gemm_pipeline;
|
||||
} kAlgorithm;
|
||||
constexpr auto block_gemm = SetBlockGemm<kAlgorithm>();
|
||||
|
||||
@@ -42,10 +42,7 @@ TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion)
|
||||
{
|
||||
constexpr struct Algorithm
|
||||
{
|
||||
struct GridwiseGemm
|
||||
{
|
||||
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4;
|
||||
} gridwise_gemm;
|
||||
ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4;
|
||||
} kAlgorithm;
|
||||
constexpr auto pipeline_version = SetGridwiseGemmPipelineVersion<kAlgorithm>();
|
||||
|
||||
@@ -78,8 +75,8 @@ TEST(ConvTuningParams, AssignsFwdConvSpecialization)
|
||||
{
|
||||
constexpr struct Algorithm
|
||||
{
|
||||
ckb::ConvFwdSpecialization fwd_specialization =
|
||||
ckb::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
ckb::ConvSpecialization fwd_specialization =
|
||||
ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0;
|
||||
} kAlgorithm;
|
||||
constexpr auto conv_spec = SetFwdConvSpecialization<kAlgorithm>();
|
||||
|
||||
|
||||
464
experimental/builder/test/unit_debug.cpp
Normal file
464
experimental/builder/test/unit_debug.cpp
Normal file
@@ -0,0 +1,464 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_foreach.hpp"
|
||||
#include "ck_tile/builder/testing/debug.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
|
||||
using ck_tile::test::StringEqWithDiff;
|
||||
using ::testing::ElementsAreArray;
|
||||
using ::testing::Eq;
|
||||
using ::testing::Gt;
|
||||
|
||||
TEST(Debug, PrintDescriptor)
|
||||
{
|
||||
auto desc =
|
||||
ckt::make_descriptor<ckb::DataType::I32>(ckt::Extent{10, 11, 12}, ckt::PackedRightLayout{});
|
||||
|
||||
std::stringstream ss;
|
||||
ckt::print_descriptor("test", desc, ss);
|
||||
|
||||
EXPECT_THAT(ss.str(),
|
||||
StringEqWithDiff( //
|
||||
"Descriptor \"test\":\n"
|
||||
" data type: I32\n"
|
||||
" size: 1'320 elements\n"
|
||||
" space: 1'320 elements (5'280 bytes)\n"
|
||||
" lengths: [10, 11, 12]\n"
|
||||
" strides: [132, 12, 1]\n"
|
||||
" packed: yes\n"));
|
||||
|
||||
// Make sure that the stream locale does not leak.
|
||||
ss.str("");
|
||||
ss << 1000;
|
||||
EXPECT_THAT(ss.str(), StringEqWithDiff("1000"));
|
||||
}
|
||||
|
||||
TEST(Debug, LimitedForeach)
|
||||
{
|
||||
{
|
||||
std::vector<size_t> values;
|
||||
size_t delim_count = 0;
|
||||
ckt::detail::limited_foreach(
|
||||
10,
|
||||
2,
|
||||
[&](auto i) { values.push_back(i); },
|
||||
[&](auto skip_count) {
|
||||
++delim_count;
|
||||
EXPECT_THAT(skip_count, Eq(10 - 2));
|
||||
});
|
||||
EXPECT_THAT(values, ElementsAreArray({0, 9}));
|
||||
EXPECT_THAT(delim_count, Eq(1));
|
||||
}
|
||||
|
||||
{
|
||||
std::vector<size_t> values;
|
||||
size_t delim_count = 0;
|
||||
ckt::detail::limited_foreach(
|
||||
100,
|
||||
9,
|
||||
[&](auto i) { values.push_back(i); },
|
||||
[&](auto skip_count) {
|
||||
++delim_count;
|
||||
EXPECT_THAT(skip_count, Eq(100 - 9));
|
||||
});
|
||||
EXPECT_THAT(values, ElementsAreArray({0, 1, 2, 3, 4, 96, 97, 98, 99}));
|
||||
EXPECT_THAT(delim_count, Eq(1));
|
||||
}
|
||||
|
||||
{
|
||||
size_t call_count = 0;
|
||||
size_t delim_count = 0;
|
||||
ckt::detail::limited_foreach(
|
||||
50,
|
||||
100,
|
||||
[&](auto i) {
|
||||
EXPECT_THAT(i, Eq(call_count));
|
||||
++call_count;
|
||||
},
|
||||
[&]([[maybe_unused]] auto skip_count) { ++delim_count; });
|
||||
EXPECT_THAT(call_count, Eq(50));
|
||||
EXPECT_THAT(delim_count, Eq(0));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Debug, PrintTensor0D)
|
||||
{
|
||||
auto desc = ckt::make_descriptor<ckb::DataType::I32>(ckt::Extent{}, ckt::PackedRightLayout{});
|
||||
|
||||
auto a = ckt::alloc_tensor_buffer(desc);
|
||||
ckt::fill_tensor_buffer(desc, a.get(), []([[maybe_unused]] size_t i) { return 123; });
|
||||
|
||||
std::stringstream ss;
|
||||
ckt::print_tensor("0D", desc, a.get(), {}, ss);
|
||||
|
||||
EXPECT_THAT(ss.str(),
|
||||
StringEqWithDiff( //
|
||||
"Tensor \"0D\": shape = []\n"
|
||||
" 123\n"));
|
||||
}
|
||||
|
||||
TEST(Debug, PrintTensor1D)
|
||||
{
|
||||
auto desc = ckt::make_descriptor<ckb::DataType::I32>(ckt::Extent{44}, ckt::PackedRightLayout{});
|
||||
|
||||
auto a = ckt::alloc_tensor_buffer(desc);
|
||||
ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i % 7; });
|
||||
|
||||
std::stringstream ss;
|
||||
ckt::print_tensor("1D", desc, a.get(), {}, ss);
|
||||
|
||||
// Note: output does not involve the size of the matrix separator fields,
|
||||
// since these are not printed.
|
||||
EXPECT_THAT(ss.str(),
|
||||
StringEqWithDiff( //
|
||||
"Tensor \"1D\": shape = [44]\n"
|
||||
" 0 1 2 3 4 ... 4 5 6 0 1\n"));
|
||||
}
|
||||
|
||||
TEST(Debug, PrintTensor4D)
|
||||
{
|
||||
auto desc = ckt::make_descriptor<ckb::DataType::I32>(ckt::Extent{100, 110, 120, 130},
|
||||
ckt::PackedRightLayout{});
|
||||
|
||||
auto a = ckt::alloc_tensor_buffer(desc);
|
||||
ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i; });
|
||||
|
||||
std::stringstream ss;
|
||||
ckt::print_tensor("4D",
|
||||
desc,
|
||||
a.get(),
|
||||
{
|
||||
// Reduce default limits to have smaller output here.
|
||||
// That also tests that we can configure these (to some
|
||||
// extent).
|
||||
.col_limit = 4,
|
||||
.row_limit = 4,
|
||||
.slice_limit = 4,
|
||||
},
|
||||
ss);
|
||||
|
||||
EXPECT_THAT(ss.str(),
|
||||
StringEqWithDiff( //
|
||||
"Tensor \"4D\": shape = [100, 110, 120, 130]\n"
|
||||
"Tensor \"4D\", slice [0, 0, :, :]\n"
|
||||
" 0 1 ... 128 129\n"
|
||||
" 130 131 ... 258 259\n"
|
||||
" ... ... ... ... ...\n"
|
||||
" 15340 15341 ... 15468 15469\n"
|
||||
" 15470 15471 ... 15598 15599\n"
|
||||
"\n"
|
||||
"Tensor \"4D\", slice [0, 1, :, :]\n"
|
||||
" 15600 15601 ... 15728 15729\n"
|
||||
" 15730 15731 ... 15858 15859\n"
|
||||
" ... ... ... ... ...\n"
|
||||
" 30940 30941 ... 31068 31069\n"
|
||||
" 31070 31071 ... 31198 31199\n"
|
||||
"\n"
|
||||
"(skipping 10'996 slices...)\n"
|
||||
"\n"
|
||||
"Tensor \"4D\", slice [99, 108, :, :]\n"
|
||||
" 171568800 171568801 ... 171568928 171568929\n"
|
||||
" 171568930 171568931 ... 171569058 171569059\n"
|
||||
" ... ... ... ... ...\n"
|
||||
" 171584140 171584141 ... 171584268 171584269\n"
|
||||
" 171584270 171584271 ... 171584398 171584399\n"
|
||||
"\n"
|
||||
"Tensor \"4D\", slice [99, 109, :, :]\n"
|
||||
" 171584400 171584401 ... 171584528 171584529\n"
|
||||
" 171584530 171584531 ... 171584658 171584659\n"
|
||||
" ... ... ... ... ...\n"
|
||||
" 171599740 171599741 ... 171599868 171599869\n"
|
||||
" 171599870 171599871 ... 171599998 171599999\n"));
|
||||
}
|
||||
|
||||
TEST(Debug, PrintTensorCustomConfig)
|
||||
{
|
||||
auto desc =
|
||||
ckt::make_descriptor<ckb::DataType::I32>(ckt::Extent{10, 10, 10}, ckt::PackedRightLayout{});
|
||||
|
||||
auto a = ckt::alloc_tensor_buffer(desc);
|
||||
ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i * 101 % 77; });
|
||||
|
||||
std::stringstream ss;
|
||||
ckt::print_tensor("CustomConfig",
|
||||
desc,
|
||||
a.get(),
|
||||
{
|
||||
// Reduce default limits to have smaller output here.
|
||||
// That also tests that we can configure these.
|
||||
.col_limit = 4,
|
||||
.row_limit = 2,
|
||||
.slice_limit = 6,
|
||||
// Try with different sizes to make sure that the alignment
|
||||
// is still correct after changing these.
|
||||
.row_prefix = ">>>>",
|
||||
.row_field_sep = "|||||",
|
||||
.row_skip_val = "-------",
|
||||
.matrix_row_skip_val = "&&&&&&&&",
|
||||
},
|
||||
ss);
|
||||
|
||||
EXPECT_THAT(ss.str(),
|
||||
StringEqWithDiff( //
|
||||
"Tensor \"CustomConfig\": shape = [10, 10, 10]\n"
|
||||
"Tensor \"CustomConfig\", slice [0, :, :]\n"
|
||||
">>>>||||| 0||||| 24|||||-------||||| 38||||| 62\n"
|
||||
">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n"
|
||||
">>>>||||| 4||||| 28|||||-------||||| 42||||| 66\n"
|
||||
"\n"
|
||||
"Tensor \"CustomConfig\", slice [1, :, :]\n"
|
||||
">>>>||||| 13||||| 37|||||-------||||| 51||||| 75\n"
|
||||
">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n"
|
||||
">>>>||||| 17||||| 41|||||-------||||| 55||||| 2\n"
|
||||
"\n"
|
||||
"Tensor \"CustomConfig\", slice [2, :, :]\n"
|
||||
">>>>||||| 26||||| 50|||||-------||||| 64||||| 11\n"
|
||||
">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n"
|
||||
">>>>||||| 30||||| 54|||||-------||||| 68||||| 15\n"
|
||||
"\n"
|
||||
"(skipping 4 slices...)\n"
|
||||
"\n"
|
||||
"Tensor \"CustomConfig\", slice [7, :, :]\n"
|
||||
">>>>||||| 14||||| 38|||||-------||||| 52||||| 76\n"
|
||||
">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n"
|
||||
">>>>||||| 18||||| 42|||||-------||||| 56||||| 3\n"
|
||||
"\n"
|
||||
"Tensor \"CustomConfig\", slice [8, :, :]\n"
|
||||
">>>>||||| 27||||| 51|||||-------||||| 65||||| 12\n"
|
||||
">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n"
|
||||
">>>>||||| 31||||| 55|||||-------||||| 69||||| 16\n"
|
||||
"\n"
|
||||
"Tensor \"CustomConfig\", slice [9, :, :]\n"
|
||||
">>>>||||| 40||||| 64|||||-------||||| 1||||| 25\n"
|
||||
">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n"
|
||||
">>>>||||| 44||||| 68|||||-------||||| 5||||| 29\n"));
|
||||
}
|
||||
|
||||
TEST(Debug, PrintTensorUnlimitedMatrix)
|
||||
{
|
||||
// To limit the output of the test, split the "unlimited" test up into one for the
|
||||
// matrices and one for the slices.
|
||||
|
||||
const ckt::Extent shape = ckt::Extent{12, 12};
|
||||
const ckt::TensorPrintConfig default_config;
|
||||
|
||||
// The shape should be larger than the default, otherwise this test doesn't make
|
||||
// any sense.
|
||||
ASSERT_THAT(shape[1], Gt(default_config.col_limit));
|
||||
ASSERT_THAT(shape[2], Gt(default_config.row_limit));
|
||||
|
||||
auto desc = ckt::make_descriptor<ckb::DataType::I32>(shape, ckt::PackedRightLayout{});
|
||||
|
||||
auto a = ckt::alloc_tensor_buffer(desc);
|
||||
ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i ^ 0xF; });
|
||||
|
||||
std::stringstream ss;
|
||||
ckt::print_tensor("UnlimitedConfig", desc, a.get(), ckt::TensorPrintConfig::unlimited(), ss);
|
||||
|
||||
EXPECT_THAT(ss.str(),
|
||||
StringEqWithDiff( //
|
||||
"Tensor \"UnlimitedConfig\": shape = [12, 12]\n"
|
||||
" 15 14 13 12 11 10 9 8 7 6 5 4\n"
|
||||
" 3 2 1 0 31 30 29 28 27 26 25 24\n"
|
||||
" 23 22 21 20 19 18 17 16 47 46 45 44\n"
|
||||
" 43 42 41 40 39 38 37 36 35 34 33 32\n"
|
||||
" 63 62 61 60 59 58 57 56 55 54 53 52\n"
|
||||
" 51 50 49 48 79 78 77 76 75 74 73 72\n"
|
||||
" 71 70 69 68 67 66 65 64 95 94 93 92\n"
|
||||
" 91 90 89 88 87 86 85 84 83 82 81 80\n"
|
||||
" 111 110 109 108 107 106 105 104 103 102 101 100\n"
|
||||
" 99 98 97 96 127 126 125 124 123 122 121 120\n"
|
||||
" 119 118 117 116 115 114 113 112 143 142 141 140\n"
|
||||
" 139 138 137 136 135 134 133 132 131 130 129 128\n"));
|
||||
}
|
||||
|
||||
TEST(Debug, PrintTensorUnlimitedSlices)
|
||||
{
|
||||
// To limit the output of the test, split the "unlimited" test up into one for the
|
||||
// matrices and one for the slices.
|
||||
|
||||
const ckt::Extent shape = ckt::Extent{13, 1, 1};
|
||||
const ckt::TensorPrintConfig default_config;
|
||||
|
||||
// The shape should be larger than the default, otherwise this test doesn't make
|
||||
// any sense.
|
||||
ASSERT_THAT(shape[0], Gt(default_config.slice_limit));
|
||||
|
||||
auto desc = ckt::make_descriptor<ckb::DataType::I32>(shape, ckt::PackedRightLayout{});
|
||||
|
||||
auto a = ckt::alloc_tensor_buffer(desc);
|
||||
ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i * 3; });
|
||||
|
||||
std::stringstream ss;
|
||||
ckt::print_tensor("UnlimitedConfig", desc, a.get(), ckt::TensorPrintConfig::unlimited(), ss);
|
||||
|
||||
EXPECT_THAT(ss.str(),
|
||||
StringEqWithDiff( //
|
||||
"Tensor \"UnlimitedConfig\": shape = [13, 1, 1]\n"
|
||||
"Tensor \"UnlimitedConfig\", slice [0, :, :]\n"
|
||||
" 0\n"
|
||||
"\n"
|
||||
"Tensor \"UnlimitedConfig\", slice [1, :, :]\n"
|
||||
" 3\n"
|
||||
"\n"
|
||||
"Tensor \"UnlimitedConfig\", slice [2, :, :]\n"
|
||||
" 6\n"
|
||||
"\n"
|
||||
"Tensor \"UnlimitedConfig\", slice [3, :, :]\n"
|
||||
" 9\n"
|
||||
"\n"
|
||||
"Tensor \"UnlimitedConfig\", slice [4, :, :]\n"
|
||||
" 12\n"
|
||||
"\n"
|
||||
"Tensor \"UnlimitedConfig\", slice [5, :, :]\n"
|
||||
" 15\n"
|
||||
"\n"
|
||||
"Tensor \"UnlimitedConfig\", slice [6, :, :]\n"
|
||||
" 18\n"
|
||||
"\n"
|
||||
"Tensor \"UnlimitedConfig\", slice [7, :, :]\n"
|
||||
" 21\n"
|
||||
"\n"
|
||||
"Tensor \"UnlimitedConfig\", slice [8, :, :]\n"
|
||||
" 24\n"
|
||||
"\n"
|
||||
"Tensor \"UnlimitedConfig\", slice [9, :, :]\n"
|
||||
" 27\n"
|
||||
"\n"
|
||||
"Tensor \"UnlimitedConfig\", slice [10, :, :]\n"
|
||||
" 30\n"
|
||||
"\n"
|
||||
"Tensor \"UnlimitedConfig\", slice [11, :, :]\n"
|
||||
" 33\n"
|
||||
"\n"
|
||||
"Tensor \"UnlimitedConfig\", slice [12, :, :]\n"
|
||||
" 36\n"));
|
||||
}
|
||||
|
||||
TEST(Debug, PrintTensorFP32)
|
||||
{
|
||||
auto desc =
|
||||
ckt::make_descriptor<ckb::DataType::FP32>(ckt::Extent{5, 5}, ckt::PackedRightLayout{});
|
||||
|
||||
auto a = ckt::alloc_tensor_buffer(desc);
|
||||
ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return std::pow(1.9999, i); });
|
||||
|
||||
std::stringstream ss;
|
||||
ckt::print_tensor("FP32", desc, a.get(), {}, ss);
|
||||
|
||||
EXPECT_THAT(ss.str(),
|
||||
StringEqWithDiff( //
|
||||
"Tensor \"FP32\": shape = [5, 5]\n"
|
||||
" 1.000 2.000 4.000 7.999 15.997\n"
|
||||
" 31.992 63.981 127.955 255.898 511.770\n"
|
||||
" 1023.488 2046.874 4093.543 8186.677 16372.535\n"
|
||||
" 32743.432 65483.590 130960.633 261908.172 523790.156\n"
|
||||
" 1047527.938 2094951.125 4189692.750 8378966.500 16757095.000\n"));
|
||||
}
|
||||
|
||||
TEST(Debug, PrintTensorBF16)
|
||||
{
|
||||
auto desc =
|
||||
ckt::make_descriptor<ckb::DataType::BF16>(ckt::Extent{5, 5}, ckt::PackedRightLayout{});
|
||||
|
||||
auto a = ckt::alloc_tensor_buffer(desc);
|
||||
ckt::fill_tensor_buffer(
|
||||
desc, a.get(), [](size_t i) { return ck::type_convert<ck::bhalf_t>(1.2345678f * i); });
|
||||
|
||||
std::stringstream ss;
|
||||
ckt::print_tensor("BF16", desc, a.get(), {}, ss);
|
||||
|
||||
EXPECT_THAT(ss.str(),
|
||||
StringEqWithDiff( //
|
||||
"Tensor \"BF16\": shape = [5, 5]\n"
|
||||
" 0.000 1.234 2.469 3.703 4.938\n"
|
||||
" 6.188 7.406 8.625 9.875 11.125\n"
|
||||
" 12.375 13.562 14.812 16.000 17.250\n"
|
||||
" 18.500 19.750 21.000 22.250 23.500\n"
|
||||
" 24.750 25.875 27.125 28.375 29.625\n"));
|
||||
}
|
||||
|
||||
TEST(Debug, PrintTensorFP8)
|
||||
{
|
||||
auto desc =
|
||||
ckt::make_descriptor<ckb::DataType::FP8>(ckt::Extent{5, 5}, ckt::PackedRightLayout{});
|
||||
|
||||
auto a = ckt::alloc_tensor_buffer(desc);
|
||||
ckt::fill_tensor_buffer(
|
||||
desc, a.get(), [](size_t i) { return ck::type_convert<ck::f8_t>(i * 0.1f); });
|
||||
|
||||
std::stringstream ss;
|
||||
ckt::print_tensor("FP8", desc, a.get(), {}, ss);
|
||||
|
||||
EXPECT_THAT(ss.str(),
|
||||
StringEqWithDiff( //
|
||||
"Tensor \"FP8\": shape = [5, 5]\n"
|
||||
" 0.000 0.102 0.203 0.312 0.406\n"
|
||||
" 0.500 0.625 0.688 0.812 0.875\n"
|
||||
" 1.000 1.125 1.250 1.250 1.375\n"
|
||||
" 1.500 1.625 1.750 1.750 1.875\n"
|
||||
" 2.000 2.000 2.250 2.250 2.500\n"));
|
||||
}
|
||||
|
||||
TEST(Debug, PrintTensorSpecialFloats)
|
||||
{
|
||||
auto desc =
|
||||
ckt::make_descriptor<ckb::DataType::FP32>(ckt::Extent{5, 5}, ckt::PackedRightLayout{});
|
||||
|
||||
auto a = ckt::alloc_tensor_buffer(desc);
|
||||
ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) {
|
||||
if(i % 8 == 1)
|
||||
return 0.f / 0.f;
|
||||
else if(i % 7 == 1)
|
||||
return std::sqrt(-1.f);
|
||||
else if(i % 6 == 1)
|
||||
return 1.f / 0.f;
|
||||
else if(i % 5 == 1)
|
||||
return -1.f / 0.f;
|
||||
else
|
||||
return static_cast<float>(i);
|
||||
});
|
||||
|
||||
std::stringstream ss;
|
||||
ckt::print_tensor("specials", desc, a.get(), {}, ss);
|
||||
|
||||
EXPECT_THAT(ss.str(),
|
||||
StringEqWithDiff( //
|
||||
"Tensor \"specials\": shape = [5, 5]\n"
|
||||
" 0.000 nan 2.000 3.000 4.000\n"
|
||||
" 5.000 -inf inf -nan nan\n"
|
||||
" 10.000 -inf 12.000 inf 14.000\n"
|
||||
" -nan -inf nan 18.000 inf\n"
|
||||
" 20.000 -inf -nan 23.000 24.000\n"));
|
||||
}
|
||||
|
||||
TEST(Debug, PrintTensorFloatPrecision)
|
||||
{
|
||||
auto desc = ckt::make_descriptor<ckb::DataType::FP32>(ckt::Extent{5}, ckt::PackedRightLayout{});
|
||||
|
||||
auto a = ckt::alloc_tensor_buffer(desc);
|
||||
ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return std::pow(0.9, i); });
|
||||
|
||||
std::stringstream ss;
|
||||
ckt::print_tensor("FloatPrecision",
|
||||
desc,
|
||||
a.get(),
|
||||
{
|
||||
.float_precision = 10,
|
||||
},
|
||||
ss);
|
||||
|
||||
EXPECT_THAT(ss.str(),
|
||||
StringEqWithDiff( //
|
||||
"Tensor \"FloatPrecision\": shape = [5]\n"
|
||||
" 1.0000000000 0.8999999762 0.8100000024 0.7289999723 0.6560999751\n"));
|
||||
}
|
||||
@@ -88,3 +88,11 @@ TEST(DeviceBuffer, AllocTensorBuffer)
|
||||
EXPECT_THAT(hipMemset(buffer.get(), 0xFF, descriptor.get_element_space_size_in_bytes()),
|
||||
HipSuccess());
|
||||
}
|
||||
|
||||
TEST(DeviceBuffer, AlignForward)
|
||||
{
|
||||
EXPECT_THAT(ckt::align_fwd(24, 8), Eq(24));
|
||||
EXPECT_THAT(ckt::align_fwd(25, 8), Eq(32));
|
||||
EXPECT_THAT(ckt::align_fwd(0xd7c563, 0x1000), Eq(0xd7d000));
|
||||
EXPECT_THAT(ckt::align_fwd(19561, 23), Eq(19573));
|
||||
}
|
||||
|
||||
@@ -6,11 +6,13 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <array>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
|
||||
using ck_tile::test::StringEqWithDiff;
|
||||
using ::testing::ElementsAreArray;
|
||||
using ::testing::Eq;
|
||||
using ::testing::Throws;
|
||||
@@ -76,7 +78,7 @@ TEST(TensorDescriptor, MakeDescriptor)
|
||||
|
||||
// Note: automatic inference of RANK.
|
||||
const auto desc =
|
||||
ckt::make_descriptor<ckb::DataType::INT32>(lengths, ckt::PackedRightLayout{});
|
||||
ckt::make_descriptor<ckb::DataType::I32>(lengths, ckt::PackedRightLayout{});
|
||||
|
||||
EXPECT_THAT(desc.get_lengths(), ElementsAreArray(lengths));
|
||||
EXPECT_THAT(desc.get_strides(),
|
||||
@@ -173,7 +175,7 @@ TEST(TensorDescriptor, ExtentFromVector)
|
||||
|
||||
TEST(TensorDescriptor, IsPacked)
|
||||
{
|
||||
constexpr auto dt = ckb::DataType::INT32; // Irrelevant for this test
|
||||
constexpr auto dt = ckb::DataType::I32; // Irrelevant for this test
|
||||
EXPECT_TRUE(
|
||||
ckt::make_descriptor<dt>(ckt::Extent{101, 43, 25, 662, 654}, ckt::PackedLeftLayout{})
|
||||
.is_packed());
|
||||
@@ -189,3 +191,20 @@ TEST(TensorDescriptor, IsPacked)
|
||||
EXPECT_FALSE(
|
||||
ckt::make_descriptor<dt>(ckt::Extent{30, 20, 10}, ckt::Extent{1, 1, 1}).is_packed());
|
||||
}
|
||||
|
||||
TEST(TensorDescriptor, PrintExtent)
|
||||
{
|
||||
{
|
||||
const ckt::Extent extent{6233, 55, 1235, 52, 203};
|
||||
std::stringstream ss;
|
||||
ss << extent;
|
||||
EXPECT_THAT(ss.str(), StringEqWithDiff("[6233, 55, 1235, 52, 203]"));
|
||||
}
|
||||
|
||||
{
|
||||
const ckt::Extent extent{};
|
||||
std::stringstream ss;
|
||||
ss << extent;
|
||||
EXPECT_THAT(ss.str(), StringEqWithDiff("[]"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,28 @@ namespace ckt = ck_tile::builder::test;
|
||||
using ::testing::Each;
|
||||
using ::testing::Eq;
|
||||
|
||||
TEST(TensorForeach, NdIter)
|
||||
{
|
||||
{
|
||||
ckt::NdIter iter(ckt::Extent{523, 345, 123, 601});
|
||||
|
||||
EXPECT_THAT(iter.numel(), Eq(13'338'296'505ULL));
|
||||
EXPECT_THAT(iter(0), Eq(ckt::Extent{0, 0, 0, 0}));
|
||||
EXPECT_THAT(iter(1), Eq(ckt::Extent{0, 0, 0, 1}));
|
||||
EXPECT_THAT(iter(601), Eq(ckt::Extent{0, 0, 1, 0}));
|
||||
EXPECT_THAT(iter(601 * 123), Eq(ckt::Extent{0, 1, 0, 0}));
|
||||
EXPECT_THAT(iter(601 * 123 * 10), Eq(ckt::Extent{0, 10, 0, 0}));
|
||||
EXPECT_THAT(iter(((34 * 345 + 63) * 123 + 70) * 601 + 5), Eq(ckt::Extent{34, 63, 70, 5}));
|
||||
}
|
||||
|
||||
{
|
||||
ckt::NdIter iter(ckt::Extent{});
|
||||
|
||||
EXPECT_THAT(iter.numel(), Eq(1));
|
||||
EXPECT_THAT(iter(0), Eq(ckt::Extent{}));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TensorForeach, CalculateOffset)
|
||||
{
|
||||
EXPECT_THAT(ckt::calculate_offset(ckt::Extent{1, 2, 3}, ckt::Extent{100, 10, 1}), Eq(123));
|
||||
@@ -87,8 +109,8 @@ TEST(TensorForeach, VisitsEveryIndex)
|
||||
|
||||
TEST(TensorForeach, FillTensorBuffer)
|
||||
{
|
||||
auto desc = ckt::make_descriptor<ckb::DataType::INT32>(ckt::Extent{31, 54, 13},
|
||||
ckt::PackedRightLayout{});
|
||||
auto desc =
|
||||
ckt::make_descriptor<ckb::DataType::I32>(ckt::Extent{31, 54, 13}, ckt::PackedRightLayout{});
|
||||
|
||||
auto buffer = ckt::alloc_tensor_buffer(desc);
|
||||
|
||||
@@ -109,7 +131,7 @@ TEST(TensorForeach, FillTensor)
|
||||
// FillTensor with non-packed indices should not write out-of-bounds.
|
||||
const ckt::Extent shape = {4, 23, 35};
|
||||
const ckt::Extent pad = {12, 53, 100};
|
||||
auto desc = ckt::make_descriptor<ckb::DataType::INT32>(shape, ckt::PackedRightLayout{}(pad));
|
||||
auto desc = ckt::make_descriptor<ckb::DataType::I32>(shape, ckt::PackedRightLayout{}(pad));
|
||||
const auto strides = desc.get_strides();
|
||||
|
||||
auto size = desc.get_element_space_size();
|
||||
@@ -169,7 +191,7 @@ TEST(TensorForeach, ClearTensorZeros)
|
||||
const ckt::Extent pad = {6, 6, 6, 6, 6, 6, 6, 6};
|
||||
|
||||
const auto desc =
|
||||
ckt::make_descriptor<ckb::DataType::INT32>(shape, ckt::PackedRightLayout{}(pad));
|
||||
ckt::make_descriptor<ckb::DataType::I32>(shape, ckt::PackedRightLayout{}(pad));
|
||||
|
||||
auto buffer = ckt::alloc_tensor_buffer(desc);
|
||||
ckt::clear_tensor_buffer(desc, buffer.get());
|
||||
|
||||
@@ -173,8 +173,8 @@ TEST(ValidationReportTests, MultipleSomeIncorrect)
|
||||
}
|
||||
|
||||
{
|
||||
auto desc = ckt::make_descriptor<ckb::DataType::INT32, 3>({'G', 'P', 'U'},
|
||||
ckt::PackedRightLayout{});
|
||||
auto desc =
|
||||
ckt::make_descriptor<ckb::DataType::I32, 3>({'G', 'P', 'U'}, ckt::PackedRightLayout{});
|
||||
|
||||
auto a = ckt::alloc_tensor_buffer(desc);
|
||||
auto b = ckt::alloc_tensor_buffer(desc);
|
||||
@@ -204,6 +204,7 @@ struct DummySignature
|
||||
constexpr DummySignature DUMMY_SIGNATURE = {};
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
template <>
|
||||
struct Args<DUMMY_SIGNATURE>
|
||||
{
|
||||
@@ -225,6 +226,7 @@ struct Outputs<DUMMY_SIGNATURE>
|
||||
void* b;
|
||||
};
|
||||
|
||||
// Explicitly implement validate for this type to test that that works.
|
||||
template <>
|
||||
ValidationReport validate<DUMMY_SIGNATURE>(const Args<DUMMY_SIGNATURE>& args,
|
||||
Outputs<DUMMY_SIGNATURE> actual,
|
||||
|
||||
@@ -15,31 +15,42 @@ using namespace test;
|
||||
constexpr DlThreadConfig DlThreadConfig_16x2x4x4x1{
|
||||
.k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
|
||||
|
||||
constexpr DlThreadConfig DlThreadConfig_16x1x4x4x1{
|
||||
.k0_per_block = 16, .k1 = 1, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1};
|
||||
|
||||
constexpr DlThreadCluster DlThreadCluster_8x2{.m1_xs = {8, 2}, .n1_xs = {8, 2}};
|
||||
|
||||
constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
constexpr DlBlockTransfer<4> DlBlockTransfer_8x1x1x2{
|
||||
.thread_slice_lengths = {8, 1, 1, 2},
|
||||
.thread_cluster_lengths = {2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {1, 2, 0, 3},
|
||||
.src_access_order = {1, 2, 0, 3},
|
||||
.src_vector_tensor_lengths = {4, 1, 1, 2},
|
||||
.src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 2}};
|
||||
|
||||
constexpr DlTransferABC DlFwdTransfer{.a =
|
||||
{
|
||||
.block_transfer = DlBlockTransferAB,
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = DlBlockTransferAB,
|
||||
},
|
||||
.c = {
|
||||
.epilogue = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
|
||||
.src_dst_vector_dim = 5,
|
||||
.dst_scalar_per_vector = 4},
|
||||
}};
|
||||
constexpr DlTransfer<4> DlTransfer4D{.a = DlBlockTransfer_8x1x1x2,
|
||||
.b = DlBlockTransfer_8x1x1x2,
|
||||
.c = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
|
||||
.src_dst_vector_dim = 5,
|
||||
.dst_scalar_per_vector = 4}};
|
||||
|
||||
constexpr TransferABC FwdTransfer_4x64x1{
|
||||
constexpr DlBlockTransfer<5> DlBlockTransfer_1x8x1x1x1{
|
||||
.thread_slice_lengths = {1, 8, 1, 1, 1},
|
||||
.thread_cluster_lengths = {1, 2, 1, 128, 1},
|
||||
.thread_cluster_arrange_order = {0, 2, 3, 1, 4},
|
||||
.src_access_order = {0, 2, 3, 1, 4},
|
||||
.src_vector_tensor_lengths = {1, 1, 1, 1, 1},
|
||||
.src_vector_tensor_contiguous_dim_order = {0, 2, 3, 1, 4},
|
||||
.dst_vector_tensor_lengths = {1, 1, 1, 1, 1}};
|
||||
|
||||
constexpr DlTransfer<5> DlTransfer5D{.a = DlBlockTransfer_1x8x1x1x1,
|
||||
.b = DlBlockTransfer_1x8x1x1x1,
|
||||
.c = {.src_dst_access_order = {0, 1, 2, 3, 4, 5},
|
||||
.src_dst_vector_dim = 5,
|
||||
.dst_scalar_per_vector = 1}};
|
||||
|
||||
constexpr Transfer<> Transfer_4x64x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
@@ -72,7 +83,73 @@ constexpr TransferABC FwdTransfer_4x64x1{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr TransferABC FwdTransfer_4x64x1_fp8{
|
||||
constexpr Transfer<4> BwdTransfer_4x64x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 4,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.block_transfer_access_order = {0, 3, 1, 2},
|
||||
.src_access_order = {0, 2, 1, 3},
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1},
|
||||
.lds_transfer = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 4,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = true},
|
||||
.block_transfer_access_order = {0, 3, 1, 2},
|
||||
.src_access_order = {0, 2, 1, 3},
|
||||
},
|
||||
.c =
|
||||
{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
},
|
||||
};
|
||||
|
||||
constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1},
|
||||
.lds_transfer = {.src_vector_dim = 1,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 2,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.block_transfer_access_order = {2, 0, 1},
|
||||
.src_access_order = {1, 0, 2},
|
||||
},
|
||||
.b =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
.lds_transfer = {.src_vector_dim = 1,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 2,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.block_transfer_access_order = {2, 0, 1},
|
||||
.src_access_order = {1, 0, 2},
|
||||
},
|
||||
.c =
|
||||
{
|
||||
.thread_cluster_dims =
|
||||
{.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
.epilogue = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 2},
|
||||
},
|
||||
};
|
||||
|
||||
constexpr Transfer<> Transfer_4x64x1_fp8{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
@@ -105,7 +182,7 @@ constexpr TransferABC FwdTransfer_4x64x1_fp8{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr TransferABC FwdTransfer_4x16x1{
|
||||
constexpr Transfer<> Transfer_4x16x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1},
|
||||
@@ -139,7 +216,7 @@ constexpr TransferABC FwdTransfer_4x16x1{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr TransferABC FwdTransfer_4x32x1{
|
||||
constexpr Transfer<> Transfer_4x32x1{
|
||||
.a =
|
||||
{
|
||||
.block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1},
|
||||
@@ -172,59 +249,80 @@ constexpr TransferABC FwdTransfer_4x32x1{
|
||||
},
|
||||
};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4};
|
||||
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{
|
||||
.k1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2};
|
||||
constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_1x1_per_wave{
|
||||
.k1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 1, .n_xdl_per_wave = 1}};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x2_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2};
|
||||
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
|
||||
constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1};
|
||||
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x2_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}};
|
||||
|
||||
constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8,
|
||||
.m_per_wmma = 32,
|
||||
.n_per_wmma = 32,
|
||||
.m_wmma_per_wave = 2,
|
||||
.n_wmma_per_wave = 1,
|
||||
.pipeline_version = PipelineVersion::V1};
|
||||
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x2_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_256x256x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x1_per_wave{
|
||||
.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_256x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{
|
||||
.k1 = 8, .m_per_wmma = 32, .n_per_wmma = 32, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_128x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{
|
||||
.k1 = 8, .m_per_wmma = 16, .n_per_wmma = 16, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_256_128x128x16{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_64_64x32x32{.block_size = 64,
|
||||
.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
constexpr ThreadBlock ThreadBlock_256_256x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_128_128x128x32{.block_size = 128,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr ThreadBlock ThreadBlock_256_128x128x32{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock_128_64x64x64{.block_size = 128,
|
||||
.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
constexpr ThreadBlock ThreadBlock_256_128x128x16{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
constexpr ThreadBlock ThreadBlock_256_128x128x8{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 8}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64,
|
||||
.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
constexpr ThreadBlock ThreadBlock_64_32x32x32{.block_size = 64,
|
||||
.tile_size = {.m = 32, .n = 32, .k = 32}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5,
|
||||
.scheduler = PipelineScheduler::INTRAWAVE};
|
||||
constexpr ThreadBlock ThreadBlock_128_64x64x64{.block_size = 128,
|
||||
.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
|
||||
constexpr BlockGemmPipeline BlockGemmDesc_v1_intrawave = {
|
||||
.pipeline_version = PipelineVersion::V1, .scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemmPipeline BlockGemmDesc_v2_intrawave = {
|
||||
.pipeline_version = PipelineVersion::V2, .scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemmPipeline BlockGemmDesc_v3_intrawave = {
|
||||
.pipeline_version = PipelineVersion::V3, .scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemmPipeline BlockGemmDesc_v4_intrawave = {
|
||||
.pipeline_version = PipelineVersion::V4, .scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
constexpr BlockGemmPipeline BlockGemmDesc_v5_intrawave = {
|
||||
.pipeline_version = PipelineVersion::V5, .scheduler = PipelineScheduler::INTRAWAVE};
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
|
||||
@@ -12,35 +12,35 @@ namespace ck_tile::builder::test_utils {
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
constexpr TileTransfer FwdTileTransfer_1x1x1{
|
||||
constexpr TileTransfer TileTransfer_1x1x1{
|
||||
.a_scalar_per_vector = 1,
|
||||
.b_scalar_per_vector = 1,
|
||||
.c_scalar_per_vector = 1,
|
||||
};
|
||||
|
||||
constexpr TileTransfer FwdTileTransfer_4x4x4{
|
||||
constexpr TileTransfer TileTransfer_4x4x4{
|
||||
.a_scalar_per_vector = 4,
|
||||
.b_scalar_per_vector = 4,
|
||||
.c_scalar_per_vector = 4,
|
||||
};
|
||||
|
||||
constexpr TileTransfer FwdTileTransfer_8x8x8{
|
||||
constexpr TileTransfer TileTransfer_8x8x8{
|
||||
.a_scalar_per_vector = 8,
|
||||
.b_scalar_per_vector = 8,
|
||||
.c_scalar_per_vector = 8,
|
||||
};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
constexpr TileThreadBlock TileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
constexpr TileThreadBlock TileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
constexpr TileThreadBlock TileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
constexpr TileThreadBlock TileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
constexpr TileThreadBlock TileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}};
|
||||
|
||||
constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
constexpr TileThreadBlock TileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}};
|
||||
|
||||
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = {
|
||||
.warps = {.m = 2, .n = 2, .k = 1},
|
||||
|
||||
@@ -54,7 +54,7 @@ inline std::string to_string<PipelineScheduler>(PipelineScheduler t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvFwdSpecialization>(ConvFwdSpecialization t)
|
||||
inline std::string to_string<ConvSpecialization>(ConvSpecialization t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t;
|
||||
@@ -86,11 +86,20 @@ inline std::string to_string<ThreadBlock>(ThreadBlock t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<GridwiseXdlGemm>(GridwiseXdlGemm t)
|
||||
inline std::string to_string<GridwiseBwdXdlGemm>(GridwiseBwdXdlGemm t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.ak1 << "," << t.bk1 << "," << t.m_per_xdl << "," << t.n_per_xdl << ","
|
||||
<< t.m_xdl_per_wave << "," << t.n_xdl_per_wave;
|
||||
oss << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << ","
|
||||
<< t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<GridwiseFwdXdlGemm>(GridwiseFwdXdlGemm t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.ak1 << "," << t.bk1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl
|
||||
<< "," << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -104,17 +113,29 @@ inline std::string to_string<GridwiseWmmaGemm>(GridwiseWmmaGemm t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BlockGemm>(BlockGemm t)
|
||||
inline std::string to_string<BlockGemmPipeline>(BlockGemmPipeline t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.scheduler) << "," << to_string(t.pipeline_version);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BlockTransfer>(BlockTransfer t)
|
||||
template <size_t ThreadClusterRank>
|
||||
inline std::string to_string(BlockTransfer<ThreadClusterRank> t)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
|
||||
if constexpr(ThreadClusterRank == 4)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 4>{t.k_batch_size, t.k0, t.m_n, t.k1});
|
||||
}
|
||||
else if constexpr(ThreadClusterRank == 3)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadClusterRank == 3 || ThreadClusterRank == 4,
|
||||
"Unsupported ThreadClusterRank");
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -134,14 +155,14 @@ inline std::string to_string<LdsTransfer>(LdsTransfer t)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<AccessOrder>(AccessOrder t)
|
||||
template <size_t N>
|
||||
inline std::string to_string(AccessOrder<N> t)
|
||||
{
|
||||
return array_to_seq(t.order);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<TransferAB>(TransferAB t)
|
||||
template <size_t N = 3>
|
||||
inline std::string to_string(InputTransfer<N> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << ","
|
||||
@@ -152,7 +173,7 @@ inline std::string to_string<TransferAB>(TransferAB t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<TransferC>(TransferC t)
|
||||
inline std::string to_string<OutputTransfer>(OutputTransfer t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << ","
|
||||
@@ -160,8 +181,8 @@ inline std::string to_string<TransferC>(TransferC t)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<TransferABC>(TransferABC t)
|
||||
template <size_t N = 3>
|
||||
inline std::string to_string(Transfer<N> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
|
||||
@@ -185,7 +206,19 @@ inline std::string to_string<DlThreadCluster>(DlThreadCluster t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlBlockTransfer>(DlBlockTransfer t)
|
||||
inline std::string to_string<DlBlockTransfer<4>>(DlBlockTransfer<4> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths)
|
||||
<< "," << array_to_seq(t.thread_cluster_arrange_order) << ","
|
||||
<< array_to_seq(t.src_access_order) << "," << array_to_seq(t.src_vector_tensor_lengths)
|
||||
<< "," << array_to_seq(t.src_vector_tensor_contiguous_dim_order) << ","
|
||||
<< array_to_seq(t.dst_vector_tensor_lengths);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlBlockTransfer<5>>(DlBlockTransfer<5> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths)
|
||||
@@ -206,19 +239,24 @@ inline std::string to_string<DlEpilogue>(DlEpilogue t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlBlockTransferAB>(DlBlockTransferAB t)
|
||||
inline std::string to_string<TransposeParams_>(TransposeParams_ t)
|
||||
{
|
||||
return to_string(t.block_transfer);
|
||||
std::ostringstream oss;
|
||||
oss << t.max_transpose_transfer_src_scalar_per_vector << ","
|
||||
<< t.max_transpose_transfer_dst_scalar_per_vector;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlBlockTransferC>(DlBlockTransferC t)
|
||||
inline std::string to_string<DlTransfer<4>>(DlTransfer<4> t)
|
||||
{
|
||||
return to_string(t.epilogue);
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlTransferABC>(DlTransferABC t)
|
||||
inline std::string to_string<DlTransfer<5>>(DlTransfer<5> t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
|
||||
@@ -234,7 +272,13 @@ inline std::string to_string<ThreadBlock_>(ThreadBlock_ t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<XdlGemm_>(XdlGemm_ t)
|
||||
inline std::string to_string<FwdXdlGemm_>(FwdXdlGemm_ t)
|
||||
{
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BwdXdlGemm_>(BwdXdlGemm_ t)
|
||||
{
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
@@ -245,33 +289,40 @@ inline std::string to_string<WmmaGemm_>(WmmaGemm_ t)
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<Transfer_>(Transfer_ t)
|
||||
template <size_t ThreadClusterRank = 3>
|
||||
inline std::string to_string(Transfer_<ThreadClusterRank> t)
|
||||
{
|
||||
return to_string(t.transfer);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvSpecialization_>(ConvSpecialization_ t)
|
||||
inline std::string to_string<ConvSpecializationFwd_>(ConvSpecializationFwd_ t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.fwd_specialization) << "," << to_string(t.gemm_specialization);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvSpecializationBwdWeight_>(ConvSpecializationBwdWeight_ t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.bwd_weight_specialization);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<Prefetch_>(Prefetch_ t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.num_gemm_k_prefetch_stages << "," << t.num_groups_to_merge << ","
|
||||
<< to_string(t.loop_scheduler);
|
||||
oss << t.num_gemm_k_prefetch_stages << "," << to_string(t.loop_scheduler);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BlockGemm_>(BlockGemm_ t)
|
||||
{
|
||||
return to_string(t.block_gemm);
|
||||
return to_string(t.block_gemm_pipeline);
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -287,7 +338,13 @@ inline std::string to_string<DlThreadCluster_>(DlThreadCluster_ t)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlTransfer_>(DlTransfer_ t)
|
||||
inline std::string to_string<DlTransfer_<4>>(DlTransfer_<4> t)
|
||||
{
|
||||
return to_string(t.transfer);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlTransfer_<5>>(DlTransfer_<5> t)
|
||||
{
|
||||
return to_string(t.transfer);
|
||||
}
|
||||
@@ -299,8 +356,8 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<XdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -309,8 +366,8 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_C
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<XdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -320,7 +377,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CS
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -332,7 +389,7 @@ inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << ","
|
||||
<< to_string(static_cast<DlThreadConfig_>(t)) << ","
|
||||
<< to_string(static_cast<DlThreadCluster_>(t)) << ","
|
||||
<< to_string(static_cast<DlTransfer_>(t));
|
||||
<< to_string(static_cast<DlTransfer_<4>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
@@ -340,7 +397,102 @@ template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor>(
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor t)
|
||||
{
|
||||
return to_string(t.base_algorithm);
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<FwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<4>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << ","
|
||||
<< to_string(static_cast<DlThreadConfig_>(t)) << ","
|
||||
<< to_string(static_cast<DlThreadCluster_>(t)) << ","
|
||||
<< to_string(static_cast<DlTransfer_<5>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<BwdXdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_<4>>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -151,7 +151,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
|
||||
static bool __host__ __device__ BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
static TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
@@ -707,7 +710,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
|
||||
__host__ __device__ static bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
static TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
|
||||
@@ -3,6 +3,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/stream_utility.hpp"
|
||||
|
||||
#include "device_grouped_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -43,6 +48,59 @@ struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout,
|
||||
{
|
||||
};
|
||||
|
||||
template <ck::index_t BlockSize>
|
||||
struct TileLoopKernelConfig
|
||||
{
|
||||
// The oversubscription factor for the number of blocks that can simultaneously reside on
|
||||
// GPU.
|
||||
static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
|
||||
// static constexpr int BLOCK_WAVES = BlockSize / get_warp_size();
|
||||
static constexpr int CU_SIMDS = 4;
|
||||
// Assume we want to have at most 2 waves per SIMD
|
||||
// static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
|
||||
static int GetCuBlocks()
|
||||
{
|
||||
int BLOCK_WAVES = BlockSize / get_warp_size();
|
||||
return ck::math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
|
||||
}
|
||||
|
||||
template <typename KernelFunction>
|
||||
static int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
|
||||
const StreamConfig& stream_config)
|
||||
{
|
||||
// Calculate max number of workgroups that can simultaneously reside on the CU.
|
||||
int occ_num_blocks = GetKernelOccupancy(kernel);
|
||||
int cu_count = getAvailableComputeUnitCount(stream_config);
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks
|
||||
<< ", available CUs count: " << cu_count << ", occup. grid size: "
|
||||
<< ck::math::min(occ_num_blocks, GetCuBlocks()) * cu_count << std::endl;
|
||||
}
|
||||
|
||||
return cu_count * ck::math::min(occ_num_blocks, GetCuBlocks());
|
||||
}
|
||||
|
||||
template <typename KernelFunction>
|
||||
static int GetKernelOccupancy(const KernelFunction& kernel)
|
||||
{
|
||||
int occupancy = 0;
|
||||
ck::hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
|
||||
return occupancy;
|
||||
}
|
||||
|
||||
static int GetComputeUnitCount()
|
||||
{
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
ck::hip_check_error(hipGetDevice(&dev));
|
||||
ck::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user