diff --git a/CHANGELOG.md b/CHANGELOG.md index 3280ad07dc..83414adc82 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added top-k sigmoid kernel in CK_TILE * Added the blockscale 2D support for CK_TILE GEMM. * Added Flatmm pipeline for microscaling (MX) FP8/FP4 data types +* Added reduce and multi reduction kernels ### Changed diff --git a/Jenkinsfile b/Jenkinsfile index 7292d9b70c..9c670183fd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -574,6 +574,8 @@ def cmake_build(Map conf=[:]){ def setup_cmd def build_cmd def execute_cmd = conf.get("execute_cmd", "") + //check the node gpu architecture + def arch_name = check_arch_name() if(!setup_args.contains("NO_CK_BUILD")){ if (params.NINJA_BUILD_TRACE) { echo "running ninja build trace" @@ -646,15 +648,15 @@ def cmake_build(Map conf=[:]){ //run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){ - sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${check_arch_name()}.json" - archiveArtifacts "ck_build_trace_${check_arch_name()}.json" - sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${check_arch_name()}.json" + sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace_${arch_name}.json" + archiveArtifacts "ck_build_trace_${arch_name}.json" + sh "python3 ../script/parse_ninja_trace.py ck_build_trace_${arch_name}.json" if (params.NINJA_BUILD_TRACE || params.BUILD_INSTANCES_ONLY){ if (params.NINJA_FTIME_TRACE) { echo "running ClangBuildAnalyzer" sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log" - sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${check_arch_name()}.log" - archiveArtifacts "clang_build_analysis_${check_arch_name()}.log" + sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis_${arch_name}.log" + archiveArtifacts "clang_build_analysis_${arch_name}.log" } @@ -672,8 +674,8 @@ def cmake_build(Map conf=[:]){ if(params.BUILD_PACKAGES){ echo "Build ckProfiler packages" sh 'ninja -j64 package' - sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${check_arch_name()}.deb" - stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${check_arch_name()}" + sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb" + stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}" } } if(params.BUILD_INSTANCES_ONLY){ @@ -699,16 +701,14 @@ def cmake_build(Map conf=[:]){ if(params.BUILD_PACKAGES){ echo "Build ckProfiler packages" sh 'ninja -j64 package' - sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${check_arch_name()}.deb" - stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${check_arch_name()}" + sh "mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64_${arch_name}.deb" + stash includes: "composablekernel-ckprofiler**.deb", name: "profiler_package_${arch_name}" } } } } } - //check the node gpu architecture - def arch_name = check_arch_name() if (params.RUN_CK_TILE_FMHA_TESTS){ try{ archiveArtifacts "perf_fmha_*.log" @@ -1201,8 +1201,8 @@ pipeline { description: "Run the ck_tile FMHA tests (default: OFF)") booleanParam( name: "RUN_TILE_ENGINE_BASIC_TESTS", - defaultValue: false, - description: "Run the tile_engine_basic tests (default: OFF)") + defaultValue: true, + description: "Run the tile_engine_basic tests (default: ON)") booleanParam( name: "RUN_TILE_ENGINE_GEMM_TESTS", defaultValue: false, @@ -1650,7 +1650,10 @@ pipeline { -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \ -D GEMM_PRESHUFFLE_LAYOUT="rcr" \ -D GEMM_PRESHUFFLE_CONFIG_FILE="default_ci_config.json" .. && \ - ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all """ + ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \ + python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ + python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ + python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args) @@ -1667,37 +1670,6 @@ pipeline { } parallel { - stage("Run TILE_ENGINE_GEMM Tests on gfx90a") - { - when { - beforeAgent true - expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() } - } - agent{ label rocmnode("gfx90a") } - environment{ - setup_args = "NO_CK_BUILD" - execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ - -D CMAKE_CXX_COMPILER="${params.BUILD_COMPILER}" \ - -D CMAKE_BUILD_TYPE=Release \ - -D GPU_TARGETS="gfx90a" \ - -D GEMM_UNIVERSAL_DATATYPE="fp8;fp16" \ - -D GEMM_UNIVERSAL_LAYOUT="rcr;rrr;crr;ccr" \ - -D GEMM_STREAMK_DATATYPE="fp8;fp16" \ - -D GEMM_STREAMK_LAYOUT="rcr" \ - -D GEMM_MULTI_D_DATATYPE="fp16" \ - -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ - -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \ - -D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \ - ninja -j${nthreads()} benchmark_gemm_universal_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \ - python3 ../tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ - python3 ../tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ - python3 ../tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """ - } - steps{ - buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args) - cleanWs() - } - } stage("Run TILE_ENGINE_GEMM Tests on gfx942") { when { diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index ce41c3310f..a7dae9dcd8 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -44,6 +44,9 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp) add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16) +add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp) +add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16) + list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_wmma_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_wmma_fp16.cpp new file mode 100644 index 0000000000..bd58ea433f --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_wmma_fp16.cpp @@ -0,0 +1,76 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include +#include + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAdd = ck::tensor_operation::element_wise::AddAdd; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr int NumDs = 2; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3 + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<4, 4, 4>>; +// clang-format on + +#include "run_grouped_gemm_multiple_d_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp index 0e1a38b19a..9fdcf4aaad 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -71,339 +71,6 @@ using DeviceGemmInstance = < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>; // clang-format on -struct ProblemSize final -{ - std::vector Ms; - std::vector Ns; - std::vector Ks; +#include "run_grouped_gemm_multiple_d_example.inc" - std::vector stride_As; - std::vector stride_Bs; - std::vector> stride_Ds; - std::vector stride_Cs; - - ck::index_t group_count; -}; - -struct ExecutionConfig final -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; -}; - -bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) -{ - auto group_count = problem_size.group_count; - - using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument; - using GemmDesc = ck::tensor_operation::device::GemmDesc; - - // GEMM shape - std::vector gemm_descs; - std::vector ggemm_kargs; - std::vector p_Cs; - std::vector p_As; - std::vector p_Bs; - std::vector> p_Ds = {}; - - gemm_descs.reserve(group_count); - ggemm_kargs.reserve(group_count); - p_As.reserve(group_count); - p_Bs.reserve(group_count); - p_Ds.reserve(group_count); - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - std::vector> a_tensors; - std::vector> b_tensors; - std::vector, NumDs>> d_tensors; - std::vector> c_host_tensors; - std::vector> c_device_result_tensors; - - a_tensors.reserve(group_count); - b_tensors.reserve(group_count); - d_tensors.reserve(group_count); - c_host_tensors.reserve(group_count); - c_device_result_tensors.reserve(group_count); - - using DeviceMemPtr = std::unique_ptr; - - std::vector a_tensors_device, b_tensors_device, c_tensors_device; - std::vector> d_tensors_device; - - a_tensors_device.reserve(group_count); - b_tensors_device.reserve(group_count); - c_tensors_device.reserve(group_count); - d_tensors_device.resize(group_count); // reserve and update vector size - - std::size_t flop = 0, num_btype = 0; - - for(int i = 0; i < group_count; i++) - { - a_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); - b_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); - - auto d0_tensor = Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); - auto d1_tensor = Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); - - std::array, NumDs> d_tens = {d0_tensor, d1_tensor}; - d_tensors.push_back(d_tens); - c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); - c_device_result_tensors.push_back(Tensor(f_host_tensor_descriptor( - problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); - std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc - << " b_k_n: " << b_tensors[i].mDesc - << " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl; - - flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; - num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() + - sizeof(BDataType) * b_tensors[i].GetElementSize() + - sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs + - sizeof(EDataType) * c_device_result_tensors[i].GetElementSize(); - - switch(config.init_method) - { - case 0: break; - case 1: - a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - for(int j = 0; j < NumDs; ++j) - { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - } - break; - case 2: - a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - for(int j = 0; j < NumDs; ++j) - { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - } - break; - default: - a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); - for(int j = 0; j < NumDs; ++j) - { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential{}); - } - } - } - - for(int i = 0; i < group_count; i++) - { - a_tensors_device.emplace_back( - std::make_unique(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType))); - b_tensors_device.emplace_back( - std::make_unique(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType))); - c_tensors_device.emplace_back(std::make_unique( - c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType))); - - for(int j = 0; j < NumDs; ++j) - { - d_tensors_device[i].emplace_back(std::make_unique( - d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType))); - } - - a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); - b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); - for(int j = 0; j < NumDs; ++j) - { - d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data()); - } - c_tensors_device[i]->SetZero(); - - p_As.push_back(a_tensors_device[i]->GetDeviceBuffer()); - p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer()); - p_Ds.push_back( - {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}); - p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); - - // The device op does not have to know M problem size at lunch time. - gemm_descs.push_back({0, - problem_size.Ns[i], - problem_size.Ks[i], - problem_size.stride_As[i], - problem_size.stride_Bs[i], - problem_size.stride_Cs[i], - {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}}); - ggemm_kargs.push_back( - {a_tensors_device[i]->GetDeviceBuffer(), - b_tensors_device[i]->GetDeviceBuffer(), - {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}, - c_tensors_device[i]->GetDeviceBuffer(), - problem_size.Ms[i], - problem_size.Ns[i], - problem_size.Ks[i], - problem_size.stride_As[i], - problem_size.stride_Bs[i], - {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}, - problem_size.stride_Cs[i]}); - } - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - - // do GEMM - auto argument = gemm.MakeArgument( - p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); - hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), - ggemm_kargs.data(), - gemm.GetDeviceKernelArgSize(&argument), - hipMemcpyHostToDevice)); - gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); - - invoker.Run(argument, StreamConfig{nullptr, false, 1}); - - bool pass = true; - if(config.do_verification) - { - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemmMultipleD; - - for(std::size_t i = 0; i < gemm_descs.size(); i++) - { - auto karg = ggemm_kargs[i]; - auto dev_res_tensor = - Tensor(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{})); - c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data()); - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], - b_tensors[i], - d_tensors[i], - c_host_tensors[i], - a_element_op, - b_element_op, - cde_element_op); - - ref_invoker.Run(ref_argument); - pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]); - } - - std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl; - } - - if(config.time_kernel) - { - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm.GetTypeString() << std::endl; - } - - return pass; -} - -std::vector argToIntArray(char* input) -{ - std::vector out; - std::istringstream in(input); - std::string item; - - while(std::getline(in, item, ',')) - { - out.push_back(std::stoi(item)); - } - return out; -} - -int main(int argc, char* argv[]) -{ - ProblemSize problem_size; - ExecutionConfig config; - - if(argc < 10) - { - std::vector Ms{64, 127, 255, 129, 260, 190, 77}; - problem_size.group_count = Ms.size(); - - for(int i = 0; i < problem_size.group_count; i++) - { - problem_size.Ms.push_back(Ms[i]); - problem_size.Ns.push_back(252); - problem_size.Ks.push_back(4608); - - problem_size.stride_As.push_back(problem_size.Ks[i]); - problem_size.stride_Bs.push_back(problem_size.Ks[i]); - problem_size.stride_Cs.push_back(problem_size.Ns[i]); - - problem_size.stride_Ds.push_back({}); - for(int j = 0; j < NumDs; ++j) - { - problem_size.stride_Ds[i].push_back(problem_size.Ns[i]); - } - } - - std::cout - << "Usage:\n" - << "arg1: verification (0=no, 1=yes)\n" - << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" - << "arg3: time kernel (0=n0, 1=yes)\n" - << "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " - "64,64 64,64 128,128)\n" - << "... setting default values." << std::endl; - } - else - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - - problem_size.Ms = argToIntArray(argv[4]); - problem_size.Ns = argToIntArray(argv[5]); - problem_size.Ks = argToIntArray(argv[6]); - - problem_size.stride_As = argToIntArray(argv[7]); - problem_size.stride_Bs = argToIntArray(argv[8]); - problem_size.stride_Cs = argToIntArray(argv[9]); - - for(int j = 0; j < NumDs; ++j) - { - problem_size.stride_Ds.push_back(problem_size.stride_Cs); - } - - problem_size.group_count = problem_size.Ms.size(); - } - - return !run_grouped_gemm(problem_size, config); -} +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp index e4da397c23..e942aad1c1 100644 --- a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp @@ -58,11 +58,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3 // clang-format off -//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; // clang-format on diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp index d5b2205892..fb3a6f0b4f 100644 --- a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp @@ -57,11 +57,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3 // clang-format off -//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; // clang-format on diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 764b533455..ffd0c5e9b7 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -323,8 +323,8 @@ bool run_grouped_gemm_example(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4: async hargs (0=n0, 1=yes)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4: async hargs (0=no, 1=yes)\n"); printf("arg5: group count (default=16)\n"); #if defined(EXAMPLE_USE_SPLITK) printf("arg6: k-batch count (default=1)\n"); diff --git a/example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc b/example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc new file mode 100644 index 0000000000..a71a23ab79 --- /dev/null +++ b/example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc @@ -0,0 +1,341 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector> stride_Ds; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument; + using GemmDesc = ck::tensor_operation::device::GemmDesc; + + // GEMM shape + std::vector gemm_descs; + std::vector ggemm_kargs; + std::vector p_Cs; + std::vector p_As; + std::vector p_Bs; + std::vector> p_Ds = {}; + + gemm_descs.reserve(group_count); + ggemm_kargs.reserve(group_count); + p_As.reserve(group_count); + p_Bs.reserve(group_count); + p_Ds.reserve(group_count); + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector, NumDs>> d_tensors; + std::vector> c_host_tensors; + std::vector> c_device_result_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_result_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + std::vector> d_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + d_tensors_device.resize(group_count); // reserve and update vector size + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + + auto d0_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + auto d1_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + + std::array, NumDs> d_tens = {d0_tensor, d1_tensor}; + d_tensors.push_back(d_tens); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_result_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc + << " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() + + sizeof(BDataType) * b_tensors[i].GetElementSize() + + sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs + + sizeof(EDataType) * c_device_result_tensors[i].GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + for(int j = 0; j < NumDs; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + for(int j = 0; j < NumDs; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + for(int j = 0; j < NumDs; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential{}); + } + } + } + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType))); + b_tensors_device.emplace_back( + std::make_unique(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType))); + c_tensors_device.emplace_back(std::make_unique( + c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType))); + + for(int j = 0; j < NumDs; ++j) + { + d_tensors_device[i].emplace_back(std::make_unique( + d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType))); + } + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + for(int j = 0; j < NumDs; ++j) + { + d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data()); + } + c_tensors_device[i]->SetZero(); + + p_As.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_Ds.push_back( + {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}); + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + + // The device op does not have to know M problem size at lunch time. + gemm_descs.push_back({0, + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + problem_size.stride_Cs[i], + {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}}); + ggemm_kargs.push_back( + {a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}, + problem_size.stride_Cs[i]}); + } + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + ggemm_kargs.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + + invoker.Run(argument, StreamConfig{nullptr, false, 1}); + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmMultipleD; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + auto karg = ggemm_kargs[i]; + auto dev_res_tensor = + Tensor(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{})); + c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data()); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + d_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + cde_element_op); + + ref_invoker.Run(ref_argument); + pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]); + } + + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl; + } + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + return pass; +} + +std::vector argToIntArray(char* input) +{ + std::vector out; + std::istringstream in(input); + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + return out; +} + +bool run_grouped_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + if(argc < 10) + { + std::vector Ms{64, 127, 255, 129, 260, 190, 77}; + problem_size.group_count = Ms.size(); + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(Ms[i]); + problem_size.Ns.push_back(252); + problem_size.Ks.push_back(4608); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + + problem_size.stride_Ds.push_back({}); + for(int j = 0; j < NumDs; ++j) + { + problem_size.stride_Ds[i].push_back(problem_size.Ns[i]); + } + } + + std::cout + << "Usage:\n" + << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n" + << "... setting default values." << std::endl; + } + else + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.Ms = argToIntArray(argv[4]); + problem_size.Ns = argToIntArray(argv[5]); + problem_size.Ks = argToIntArray(argv[6]); + + problem_size.stride_As = argToIntArray(argv[7]); + problem_size.stride_Bs = argToIntArray(argv[8]); + problem_size.stride_Cs = argToIntArray(argv[9]); + + for(int j = 0; j < NumDs; ++j) + { + problem_size.stride_Ds.push_back(problem_size.stride_Cs); + } + + problem_size.group_count = problem_size.Ms.size(); + } + + return run_grouped_gemm(problem_size, config); +} diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp index ae707e74a2..ccb3a9c435 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp @@ -119,7 +119,7 @@ static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_an static constexpr bool MulRoutedWeight = false; // splitk gemm1 does not do routedWeight. #if 1 -static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t MPerBlock = 64; static constexpr ck::index_t NPerBlock = 128; static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1); @@ -156,7 +156,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, + int32_t, A0DataType, A0DataType, A0DataType, A0DataType, true>; #else static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< @@ -171,7 +172,8 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, + int32_t, A0DataType, A0DataType, A0DataType, A0DataType, false>; #endif // clang-format on @@ -182,12 +184,14 @@ int main(int argc, char* argv[]) bool time_kernel = true; #if 1 // GEMM shape - ck::index_t N = 4096; - ck::index_t K = 6144; + ck::index_t N = 1536; + ck::index_t K = 4096; + // ck::index_t N = 4096; + // ck::index_t K = 6144; // ck::index_t N = 128; // ck::index_t K = 512; - ck::index_t experts = 8; - ck::index_t topk = 2; + ck::index_t experts = 16; + ck::index_t topk = 8; // ck::index_t sorted_tile_num = 515; // ck::index_t valid_tile_num = 512; // ck::index_t tokens = 208; @@ -196,9 +200,9 @@ int main(int argc, char* argv[]) // ck::index_t sorted_tile_num = 259; // ck::index_t valid_tile_num = 256; // ck::index_t tokens = 4096; - ck::index_t sorted_tile_num = 2; - ck::index_t valid_tile_num = 2; - ck::index_t tokens = 32; + ck::index_t sorted_tile_num = 16; + ck::index_t valid_tile_num = 16; + ck::index_t tokens = 4; #else // deepseek ck::index_t N = 2048; @@ -209,7 +213,7 @@ int main(int argc, char* argv[]) ck::index_t sorted_tile_num = 261; ck::index_t valid_tile_num = 256; #endif - ck::index_t KBatch = 6; + ck::index_t KBatch = 1; if(argc == 1) { // use default case diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index c4c70009d5..37d296aa91 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -36,7 +36,7 @@ DTYPE_BITS = { K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} -SUPPORTED_PAGE_SIZE = [128, 256, 1024] +SUPPORTED_PAGE_SIZE = [1, 128, 256, 1024] SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"] SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"] KV_MEMORY_LAYOUT_ENUM_MAP = { @@ -737,6 +737,8 @@ def get_fwd_blobs( # Generate kernels for both page_size=16 and page_size=1024 for page_size in SUPPORTED_PAGE_SIZE: + if page_size == 1 and pipeline.F_kv_memory_layout != "linear": + continue k = FmhaFwdKernel( F_idx=0, F_hdim=hdim, diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 536fcb0692..7e1fa3e0a8 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -1351,8 +1351,8 @@ fwd_result fmha_fwd_run(mode_enum mode, auto oacc_element_func = [&]() { if constexpr(std::is_same_v && supports_qscale) - return ck_tile::composes(ck_tile::saturates{}, - ck_tile::scales{scale_o_host}); + return ck_tile::make_composes(ck_tile::saturates{}, + ck_tile::scales{scale_o_host}); else if constexpr(supports_qscale) return ck_tile::scales{scale_o_host}; else diff --git a/example/ck_tile/05_reduce/CMakeLists.txt b/example/ck_tile/05_reduce/CMakeLists.txt index 715ed35394..074b594534 100644 --- a/example/ck_tile/05_reduce/CMakeLists.txt +++ b/example/ck_tile/05_reduce/CMakeLists.txt @@ -15,6 +15,22 @@ list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-flo target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) +# Multi Reduce Threadwise Example +set(EXAMPLE_MULTI_REDUCE "tile_example_multi_reduce_threadwise") +add_executable(${EXAMPLE_MULTI_REDUCE} EXCLUDE_FROM_ALL multiple_reduce_threadwise.cpp) +target_include_directories(${EXAMPLE_MULTI_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS) +list(APPEND EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +target_compile_options(${EXAMPLE_MULTI_REDUCE} PRIVATE ${EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS}) + +# Multi Reduce Blockwise Example +set(EXAMPLE_MULTI_REDUCE_BLOCKWISE "tile_example_multi_reduce_multiblock") +add_executable(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} EXCLUDE_FROM_ALL multiple_reduce_multiblock.cpp) +target_include_directories(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS) +list(APPEND EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +target_compile_options(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS}) + # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long # however, this property may affect global diff --git a/example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp b/example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp new file mode 100644 index 0000000000..2384dc2aa5 --- /dev/null +++ b/example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp @@ -0,0 +1,271 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/utility/json_dump.hpp" +#include + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("n", "32", "n dimension") + .insert("h", "19", "h dimension") + .insert("w", "7", "w dimension") + .insert("c", "512", "c dimension") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "multi_reduce_multiblock.json", "json file name to dump results"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + using XDataType = DataType; + using ComputeDataType = float; + using YDataType = float; + + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t H = arg_parser.get_int("h"); + ck_tile::index_t W = arg_parser.get_int("w"); + ck_tile::index_t C = arg_parser.get_int("c"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + // Validate input dimensions + const ck_tile::index_t kept_dim_len_prod = N * C; + const ck_tile::index_t reduce_total_length = H * W; + + if(kept_dim_len_prod == 0) + { + std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C + << ", product=" << kept_dim_len_prod << ")." << std::endl; + std::cerr << "This will result in an empty output tensor." << std::endl; + return false; + } + + if(reduce_total_length == 0) + { + std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W + << ", product=" << reduce_total_length << ")." << std::endl; + std::cerr << "This will result in an empty reduction with no data to process." << std::endl; + std::cerr << "The kernel will exit early without performing any computation." << std::endl; + return false; + } + + std::vector problem_shape = {N, H, W, C}; + std::vector strides(4); + strides[0] = H * W * C; + strides[1] = W * C; + strides[2] = C; + strides[3] = 1; + + // Define reduction specification: + constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep + constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce + + ck_tile::HostTensor x_host(problem_shape, strides); + ck_tile::HostTensor y_host_add_ref({N, C}, {C, 1}); + ck_tile::HostTensor y_host_max_ref({N, C}, {C, 1}); + auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_max_ref); + + ck_tile::HostTensor y_host_add_dev({N, C}, {C, 1}); + ck_tile::HostTensor y_host_max_dev({N, C}, {C, 1}); + auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_max_dev); + + const auto number_operations = y_host_dev_tuple.size(); + + std::vector h(number_operations * N * C); + + auto y_buf_size = number_operations * + y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes(); + ck_tile::DeviceMem y_buf(y_buf_size); + + const auto output_tensor_offset = N * C; + + // Operations: one doing a sum reduction, the other computing the mean square + // In the case of mean square: + // 1. The element wise operation squares each element before reduction + // 2. The reduction operation sum the squared element + // 3. The accumulator element wise operation divides the result by the total number of reduced + // elements (intra block operation) + // 4. The partial result is updated across blocks using inter block reduction, a sum. + auto reduce_ops = + ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reductions + auto elementwise_ops = ck_tile::make_tuple(ck_tile::element_wise::PassThrough{}, + ck_tile::element_wise::UnarySquare{}); // Elementwise + // ops + auto accumulator_elementwise_ops = ck_tile::make_tuple( + ck_tile::element_wise::PassThrough{}, + ck_tile::element_wise::UnaryDivide{ + reduce_total_length}); // Accumulator Elementwise ops on reduction, intra block + auto inter_block_reduce_ops = ck_tile::make_tuple( + ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // Inter block reduction + + ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + + using BlockWarps = ck_tile::sequence<4, 1>; + using BlockTile = ck_tile::sequence<128, 128>; + using WarpTile = ck_tile::sequence<32, 128>; + using ThreadTile = ck_tile::sequence<8, 8>; + + constexpr ck_tile::index_t kBlockPerCu = 1; + + using Shape = ck_tile::Reduce2dShape; + using Problem = ck_tile::Reduce2dProblem; + + using Kernel = ck_tile::MultiReduceMultiblock; + + // Determine block group size for multi-block reduction + // block_group_size records how many blocks participate to a reduction (input data dependent) + // , for efficiency reasons this size if limited to a maximum of 128. If this is not sufficient + // to process the whole reduction, each thread will to process multiple thread tile + // a num_block_tile_iterations times + auto [num_block_tile_iterations, block_group_size] = + typename Kernel::TilePartitioner{reduce_total_length}.GetBlockGroupParams(); + + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); + ck_tile::index_t kGridSize = + ((kept_dim_len_prod + Shape::Block_M - 1) / Shape::Block_M) * block_group_size; + + std::cout << "Block group size: " << block_group_size + << ", Num block tile iterations: " << num_block_tile_iterations + << ", Reduce total length: " << reduce_total_length << std::endl; + std::cout << "grid size " << kGridSize << ", block size " << kBlockSize << std::endl; + + // Create input tensor shape and strides + auto input_shape = + ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]); + auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]); + + if(!Kernel::IsSupportedArgument( + C, input_strides)) // output tensor's continuous dimension and input strides + { + throw std::runtime_error("Wrong! Arguments not supported!\n"); + } + + // Init the output data with identity values respective to each reduce op + ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + constexpr auto op = reduce_ops.at(i); + const auto identity_val = op.template GetIdentityValue(); + const auto output_number_elements = N * C; + std::fill(h.begin() + i * output_number_elements, + h.begin() + (i + 1) * output_number_elements, + identity_val); + }); + + auto clear_output_buffer = [&]() { y_buf.ToDevice(h.data()); }; + + float ave_time = launch_kernel_time_mask( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + clear_output_buffer, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + input_shape, + input_strides, + kept_dim, + reduce_dims, + output_tensor_offset, + elementwise_ops, + accumulator_elementwise_ops, + inter_block_reduce_ops) + + ); + + std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(do_validation) + { + // reference + ck_tile::reference_multiple_reduce_multiblock( + x_host, + y_host_ref_tuple, + reduce_ops, + kept_dim, + reduce_dims, + elementwise_ops, + accumulator_elementwise_ops, + inter_block_reduce_ops, + block_group_size); + std::cout << "Read " << y_buf_size / 10 << " Bytes from the device" << std::endl; + + // Transfer data from device and check error for each operation + y_buf.FromDevice(h.data()); + ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + std::memcpy(y_host_dev_tuple.get(ck_tile::number{}).data(), + h.data() + i * output_tensor_offset, + output_tensor_offset * sizeof(YDataType)); + std::cout << "Checking operation " << i << ": " << std::endl; + + bool pass_op = ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number{}), + y_host_ref_tuple.get(ck_tile::number{})); + + if(pass_op) + { + std::cout << "✅ valid results for this operation" << std::endl; + } + pass &= pass_op; + }); + + std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } +} diff --git a/example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp b/example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp new file mode 100644 index 0000000000..c929a7eb82 --- /dev/null +++ b/example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp @@ -0,0 +1,224 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/utility/json_dump.hpp" +#include + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("n", "32", "n dimension") + .insert("h", "7", "h dimension") + .insert("w", "7", "w dimension") + .insert("c", "512", "c dimension") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "multi_reduce.json", "json file name to dump results"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + using XDataType = DataType; + using ComputeDataType = float; + using YDataType = DataType; + + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t H = arg_parser.get_int("h"); + ck_tile::index_t W = arg_parser.get_int("w"); + ck_tile::index_t C = arg_parser.get_int("c"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + // Validate input dimensions + const ck_tile::index_t kept_dim_len_prod = N * C; + const ck_tile::index_t reduce_total_length = H * W; + + if(kept_dim_len_prod == 0) + { + std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C + << ", product=" << kept_dim_len_prod << ")." << std::endl; + std::cerr << "This will result in an empty output tensor." << std::endl; + return false; + } + + if(reduce_total_length == 0) + { + std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W + << ", product=" << reduce_total_length << ")." << std::endl; + std::cerr << "This will result in an empty reduction with no data to process." << std::endl; + std::cerr << "The kernel will exit early without performing any computation." << std::endl; + return false; + } + + std::vector problem_shape = {N, H, W, C}; + std::vector strides(4); + strides[0] = H * W * C; + strides[1] = W * C; + strides[2] = C; + strides[3] = 1; + + // Define reduction specification: + constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep + constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce + + ck_tile::HostTensor x_host(problem_shape, strides); + ck_tile::HostTensor y_host_add_ref({N, C}, {C, 1}); + ck_tile::HostTensor y_host_max_ref({N, C}, {C, 1}); + auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_max_ref); + + ck_tile::HostTensor y_host_add_dev({N, C}, {C, 1}); + ck_tile::HostTensor y_host_max_dev({N, C}, {C, 1}); + auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_max_dev); + + const auto number_operations = y_host_dev_tuple.size(); + + // Two operations: one do a sum reduction, the other computing the mean square + auto reduce_ops = + ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reductions ops + auto elementwise_ops = + ck_tile::make_tuple(ck_tile::element_wise::PassThrough{}, + ck_tile::element_wise::UnarySquare{}); // Elementwise ops + auto accumulator_elementwise_ops = + ck_tile::make_tuple(ck_tile::element_wise::PassThrough{}, + ck_tile::element_wise::UnaryDivide{ + reduce_total_length}); // Accumulator Elementiwise ops on reduction, + + auto y_buf_size = number_operations * + y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes(); + ck_tile::DeviceMem y_buf(y_buf_size); + + const auto output_tensor_offset = N * C; + + ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + + using BlockWarps = ck_tile::sequence<4, 1>; + using BlockTile = ck_tile::sequence<128, 128>; + using WarpTile = ck_tile::sequence<32, 128>; + using ThreadTile = ck_tile::sequence<8, 8>; + + constexpr ck_tile::index_t kBlockPerCu = 1; + ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) / + BlockTile::at(ck_tile::number<0>{}); + std::cout << "grid size " << kGridSize << std::endl; + + using Shape = ck_tile::Reduce2dShape; + using Problem = ck_tile::Reduce2dProblem; + + using Kernel = ck_tile::MultiReduceThreadWise; + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); + + // Create input tensor shape and strides + auto input_shape = + ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]); + auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]); + + if(!Kernel::IsSupportedArgument( + C, input_strides)) // output tensor's continuous dimension and input strides + { + throw std::runtime_error("Wrong! Arguments not supported!\n"); + } + + float ave_time = launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + input_shape, + input_strides, + kept_dim, + reduce_dims, + output_tensor_offset, + elementwise_ops, + accumulator_elementwise_ops)); + + std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(do_validation) + { + std::vector h(number_operations * N * C); + + // reference + ck_tile::reference_multiple_reduce( + x_host, + y_host_ref_tuple, + reduce_ops, + kept_dim, + reduce_dims, + elementwise_ops, + accumulator_elementwise_ops); + std::cout << "Read " << y_buf_size / 10 << " Bytes from the device" << std::endl; + + // Transfer data from device and check error for each operation + y_buf.FromDevice(h.data()); + ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + std::memcpy(y_host_dev_tuple.get(ck_tile::number{}).data(), + h.data() + i * output_tensor_offset, + output_tensor_offset * sizeof(YDataType)); + pass &= ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number{}), + y_host_ref_tuple.get(ck_tile::number{})); + }); + + std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } +} diff --git a/experimental/builder/README.md b/experimental/builder/README.md index 940ee3e503..1156de0e9c 100644 --- a/experimental/builder/README.md +++ b/experimental/builder/README.md @@ -45,6 +45,11 @@ cmake .. ``` +Note: The tests for WMMA builders are only built when `CK_USE_WMMA` is enabled. Add e.g. +`gfx1121` or any of the other `gfx11`/`gfx12` architectures to the GPU targets. Alternatively, +one can add flag `-D CK_USE_WMMA=ON` to build the tests. For the end-to-end tests that use +the instances from builder, one needs an actual Navi card. + ## Building and Testing The builder test suite is organized into two main categories: diff --git a/experimental/builder/include/ck_tile/builder/README.md b/experimental/builder/include/ck_tile/builder/README.md index 8075e33220..af8c4ec01b 100644 --- a/experimental/builder/include/ck_tile/builder/README.md +++ b/experimental/builder/include/ck_tile/builder/README.md @@ -85,21 +85,23 @@ The top-level signature contains global properties that apply to the entire conv template concept ConvSignatureDescriptor = requires(T t) { { t.spatial_dim } -> std::convertible_to; // 1, 2, or 3 - { t.data_type } -> std::convertible_to; // Default data type { t.input } -> ConvTensorDescriptor; { t.weight } -> ConvTensorDescriptor; { t.output } -> ConvTensorDescriptor; requires ConvolutionDirectionWellDefinedIfProvided; // Optional direction + requires detail::DataTypeWellDefinedIfProvided; // Optional default data type + requires detail::ElementwiseOpWellDefinedIfProvided; // Optional default elementwise operation }; ``` **Properties:** - **`spatial_dim`**: Dimensionality of the convolution (1D, 2D, or 3D) -- **`direction`**: Operation type (optional, defaults to FORWARD) +- **`direction`**: Operation type (Optional, defaults to FORWARD) - `FORWARD`: Standard forward convolution - `BACKWARD_DATA`: Gradient computation w.r.t. input - `BACKWARD_WEIGHT`: Gradient computation w.r.t. weights -- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8) +- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8). (Optional, defaults to UNDEFINED_DATA_TYPE, may be overridden by tensors) +- **`operation`**: Default Operation (Optional, defaults to PASS_THROUGH, may be overridden by tensors) - **`accumulation_data_type`**: Type used for internal accumulation #### 2. Tensor Level @@ -116,7 +118,7 @@ concept ConvTensorDescriptor = requires(T t) { A tensor descriptor encapsulates: - **Configuration**: Layout and data type information -- **Operation** (optional): Fused elementwise operations on this tensor +- **operation** Fused elementwise operations on this tensor (Optional, default provided by ConvSignatureDescriptor) #### 3. Tensor Configuration @@ -126,7 +128,7 @@ Describes the memory layout and data types: template concept TensorConfigDescriptor = requires(T t) { { t.layout } -> std::convertible_to; - { t.data_type } -> std::convertible_to; // Optional override + requires detail::DataTypeWellDefinedIfProvided; // Override data type (Optional, default provided by ConvSignatureDescriptor) }; ``` diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index bf7e89fcaa..791924ccd4 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -15,29 +15,31 @@ namespace ck_tile::builder { /* Descriptors for individual elements of the algorithm description */ /********************************************************************/ +// Common concept for size-related fields +template +concept SizeType = std::unsigned_integral>; + // Concept for thread block dimensions for a GEMM problem. template concept ThreadBlockDescriptor = requires(T t) { - { t.block_size } -> std::convertible_to; - { t.tile_size.m } -> std::convertible_to; - { t.tile_size.n } -> std::convertible_to; - { t.tile_size.k } -> std::convertible_to; + { t.block_size } -> SizeType; + { t.tile_size.m } -> SizeType; + { t.tile_size.n } -> SizeType; + { t.tile_size.k } -> SizeType; }; // Concept for parameters that describe a gridwise XDL GEMM problem. template concept GridwiseXdlGemmDescriptor = requires(T t) { - { t.ak1 } -> std::convertible_to; - { t.bk1 } -> std::convertible_to; - { t.m_per_xdl } -> std::convertible_to; - { t.n_per_xdl } -> std::convertible_to; - { t.m_xdl_per_wave } -> std::convertible_to; - { t.n_xdl_per_wave } -> std::convertible_to; + { t.m_per_xdl } -> SizeType; + { t.n_per_xdl } -> SizeType; + { t.m_xdl_per_wave } -> SizeType; + { t.n_xdl_per_wave } -> SizeType; }; // Concept for parameter that describe block GEMM problem. template -concept BlockGemmDescriptor = requires(T t) { +concept BlockGemmPipelineDescriptor = requires(T t) { { t.pipeline_version } -> std::convertible_to; { t.scheduler } -> std::convertible_to; }; @@ -45,37 +47,48 @@ concept BlockGemmDescriptor = requires(T t) { // Concept for parameters that describe a gridwise WMMA GEMM problem. template concept GridwiseWmmaGemmDescriptor = requires(T t) { - { t.k1 } -> std::convertible_to; - { t.m_per_wmma } -> std::convertible_to; - { t.n_per_wmma } -> std::convertible_to; - { t.m_wmma_per_wave } -> std::convertible_to; - { t.n_wmma_per_wave } -> std::convertible_to; - { t.pipeline_version } -> std::convertible_to; + { t.k1 } -> SizeType; + { t.m_per_wmma } -> SizeType; + { t.n_per_wmma } -> SizeType; + { t.m_wmma_per_wave } -> SizeType; + { t.n_wmma_per_wave } -> SizeType; }; // Concept for vectorized data transfer for convolution input tensors. template -concept BlockTransferDescriptor = requires(T t) { - { t.k0 } -> std::convertible_to; - { t.m_n } -> std::convertible_to; - { t.k1 } -> std::convertible_to; +concept BlockTransferDescriptor3D = requires(T t) { + { t.k0 } -> SizeType; + { t.m_n } -> SizeType; + { t.k1 } -> SizeType; }; +template +concept BlockTransferDescriptor4D = requires(T t) { + { t.k0 } -> SizeType; + { t.m_n } -> SizeType; + { t.k1 } -> SizeType; + { t.k_batch_size } -> SizeType; +}; + +template +concept BlockTransferDescriptor = (ThreadClusterRank == 3 && BlockTransferDescriptor3D) || + (ThreadClusterRank == 4 && BlockTransferDescriptor4D); + // Concept for thread cluster dimensions for GEMM output tensor. template concept ThreadClusterDescriptor = requires(T t) { - { t.m_block } -> std::convertible_to; - { t.m_wave_per_xdl } -> std::convertible_to; - { t.n_block } -> std::convertible_to; - { t.n_wave_per_xdl } -> std::convertible_to; + { t.m_block } -> SizeType; + { t.m_wave_per_xdl } -> SizeType; + { t.n_block } -> SizeType; + { t.n_wave_per_xdl } -> SizeType; }; // Concept for the LDS transfer for the convolution input tensors. template concept LdsTransferDescriptor = requires(T t) { - { t.src_vector_dim } -> std::convertible_to; - { t.src_scalar_per_vector } -> std::convertible_to; - { t.lds_dst_scalar_per_vector } -> std::convertible_to; + { t.src_vector_dim } -> SizeType; + { t.src_scalar_per_vector } -> SizeType; + { t.lds_dst_scalar_per_vector } -> SizeType; { t.is_direct_load } -> std::convertible_to; { t.lds_padding } -> std::convertible_to; }; @@ -84,33 +97,35 @@ concept LdsTransferDescriptor = requires(T t) { // LDS). template concept EpilogueDescriptor = requires(T t) { - { t.m_xdl_per_wave_per_shuffle } -> std::convertible_to; - { t.n_per_wave_per_shuffle } -> std::convertible_to; - { t.scalar_per_vector } -> std::convertible_to; + { t.m_xdl_per_wave_per_shuffle } -> SizeType; + { t.n_per_wave_per_shuffle } -> SizeType; + { t.scalar_per_vector } -> SizeType; }; // Concept for the thread cluster access order template concept AccessOrderDescriptor = requires(T t) { { t.order } -> std::convertible_to>; +} || requires(T t) { + { t.order } -> std::convertible_to>; }; // Concept for thread block dimensions for a GEMM problem for CK Tile (Block // size is deduced from block gemm structure). template concept TileThreadBlockDescriptor = requires(T t) { - { t.tile_size.m } -> std::convertible_to; - { t.tile_size.n } -> std::convertible_to; - { t.tile_size.k } -> std::convertible_to; + { t.tile_size.m } -> SizeType; + { t.tile_size.n } -> SizeType; + { t.tile_size.k } -> SizeType; }; // Concept for thread block dimensions for a GEMM problem for CK Tile (Block // size is deduced from block gemm structure). template concept TileTransferDescriptor = requires(T t) { - { t.a_scalar_per_vector } -> std::convertible_to; - { t.b_scalar_per_vector } -> std::convertible_to; - { t.c_scalar_per_vector } -> std::convertible_to; + { t.a_scalar_per_vector } -> SizeType; + { t.b_scalar_per_vector } -> SizeType; + { t.c_scalar_per_vector } -> SizeType; }; // Concept to check if struct specifies block GEMM (CK Tile). @@ -159,30 +174,51 @@ concept SpecifiesTileThreadBlock = requires { // Concept to check if a struct specifies gridwise XDL GEMM info. template -concept SpecifiesGridwiseXdlGemm = requires { - { T::gridwise_gemm } -> GridwiseXdlGemmDescriptor; +concept GridwiseFwdXdlGemmDescriptor = requires(T t) { + { t.ak1 } -> SizeType; + { t.bk1 } -> SizeType; + { t.xdl_params } -> GridwiseXdlGemmDescriptor; +}; + +// Concept to check if a struct specifies gridwise XDL GEMM info. +template +concept GridwiseBwdXdlGemmDescriptor = requires(T t) { + { t.k1 } -> SizeType; + { t.xdl_params } -> GridwiseXdlGemmDescriptor; +}; + +// Concept to check if a struct specifies gridwise XDL GEMM info. +template +concept SpecifiesGridwiseFwdXdlGemm = requires(T t) { + { t.gridwise_gemm } -> GridwiseFwdXdlGemmDescriptor; +}; + +// Concept to check if a struct specifies gridwise XDL GEMM info. +template +concept SpecifiesGridwiseBwdXdlGemm = requires(T t) { + { t.gridwise_gemm } -> GridwiseBwdXdlGemmDescriptor; }; // Concept to check if a struct specifies gridwise WMMA GEMM info. template -concept SpecifiesGridwiseWmmaGemm = requires { - { T::gridwise_gemm } -> GridwiseWmmaGemmDescriptor; +concept SpecifiesGridwiseWmmaGemm = requires(T t) { + { t.gridwise_gemm } -> GridwiseWmmaGemmDescriptor; }; // Concept to check if a struct specifies convolution input and output block transfer info. -template +template concept SpecifiesBlockTransfer = requires(T t) { - { T::transfer.a.block_transfer } -> BlockTransferDescriptor; - { T::transfer.b.block_transfer } -> BlockTransferDescriptor; + { T::transfer.a.block_transfer } -> BlockTransferDescriptor; + { T::transfer.b.block_transfer } -> BlockTransferDescriptor; { T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor; }; // Concept to check if a struct specifies convolution scalar per vector infor for A, B and C. template concept SpecifiesTileTransfer = requires(T t) { - { T::transfer.a_scalar_per_vector } -> std::convertible_to; - { T::transfer.b_scalar_per_vector } -> std::convertible_to; - { T::transfer.c_scalar_per_vector } -> std::convertible_to; + { T::transfer.a_scalar_per_vector } -> SizeType; + { T::transfer.b_scalar_per_vector } -> SizeType; + { T::transfer.c_scalar_per_vector } -> SizeType; }; // Concept to check if a struct specifies LDS transfer info for tensors A, B, and C. @@ -210,8 +246,12 @@ concept SpecifiesSourceAccessOrder = requires(T t) { // Concept to check if struct specifies block GEMM. template concept SpecifiesBlockGemm = requires { - { T::block_gemm.pipeline_version } -> std::convertible_to; - { T::block_gemm.scheduler } -> std::convertible_to; + { T::block_gemm_pipeline } -> BlockGemmPipelineDescriptor; +}; + +template +concept SpecifiesGridwiseGemmPipeline = requires { + { T::pipeline_version } -> std::convertible_to; }; // Concept to check if struct specifies block GEMM (CK Tile). @@ -244,7 +284,12 @@ concept SpecifiesTileConvSpecialization = requires { template concept SpecifiesFwdConvSpecialization = requires { - { T::fwd_specialization } -> std::convertible_to; + { T::fwd_specialization } -> std::convertible_to; +}; + +template +concept SpecifiesBwdWeightConvSpecialization = requires { + { T::bwd_weight_specialization } -> std::convertible_to; }; template @@ -254,12 +299,12 @@ concept SpecifiesGemmSpecialization = requires { template concept SpecifiesNumPrefetchStages = requires { - { T::num_gemm_k_prefetch_stages } -> std::convertible_to; + { T::num_gemm_k_prefetch_stages } -> SizeType; }; template concept SpecifiesNumGroupsToMerge = requires { - { T::num_groups_to_merge } -> std::convertible_to; + { T::num_conv_groups_to_merge } -> SizeType; }; template @@ -267,12 +312,59 @@ concept SpecifiesLoopScheduler = requires { { T::loop_scheduler } -> std::convertible_to; }; +template +concept SpecifiesGenericInstance = !requires { + { T::specialization }; +}; + +template +concept SpecifiesTransposeTransfer = requires { + { T::max_transpose_transfer_src_scalar_per_vector } -> SizeType; + { T::max_transpose_transfer_dst_scalar_per_vector } -> SizeType; +}; + +template +concept HasTransposeTransfer = requires { + { T::max_transpose_transfer_src_scalar_per_vector }; + { T::max_transpose_transfer_dst_scalar_per_vector }; +}; + +template +concept TransposeTransferWellDefinedIfProvided = + !HasTransposeTransfer || SpecifiesTransposeTransfer; + +template +concept SpecifiesGemmBatchOptions = requires { + { T::num_conv_groups_to_merge } -> SizeType; +}; + +/******************************************** */ +/* Algorithm specialization concepts */ +/******************************************** */ template concept SpecifiesLargeTensorSupport = requires { { T::specialization } -> std::convertible_to; requires T::specialization == ConvAlgorithmSpecialization::LARGE_TENSOR; }; +template +concept SpecifiesReferenceAlgorithm = requires { + { T::specialization } -> std::convertible_to; + requires T::specialization == ConvAlgorithmSpecialization::REFERENCE; +}; + +template +concept SpecifiesTwoStageSupport = requires { + { T::specialization } -> std::convertible_to; + requires T::specialization == ConvAlgorithmSpecialization::TWO_STAGE; +}; + +template +concept SpecifiesMultipleDSupport = requires { + { T::specialization } -> std::convertible_to; + requires T::specialization == ConvAlgorithmSpecialization::MULTIPLE_D; +}; + /******************************************** */ /* DL-specific descriptors and requirements */ /******************************************** */ @@ -280,11 +372,11 @@ concept SpecifiesLargeTensorSupport = requires { // Concept for DL thread configuration template concept DlThreadConfigDescriptor = requires(T t) { - { t.k0_per_block } -> std::convertible_to; - { t.k1 } -> std::convertible_to; - { t.m1_per_thread } -> std::convertible_to; - { t.n1_per_thread } -> std::convertible_to; - { t.k_per_thread } -> std::convertible_to; + { t.k0_per_block } -> SizeType; + { t.k1 } -> SizeType; + { t.m1_per_thread } -> SizeType; + { t.n1_per_thread } -> SizeType; + { t.k_per_thread } -> SizeType; }; // Concept for DL thread cluster @@ -295,23 +387,29 @@ concept DlThreadClusterDescriptor = requires(T t) { }; // Concept for DL block transfer -template +template concept DlBlockTransferDescriptor = requires(T t) { - { t.thread_slice_lengths } -> std::convertible_to>; - { t.thread_cluster_lengths } -> std::convertible_to>; - { t.thread_cluster_arrange_order } -> std::convertible_to>; - { t.src_access_order } -> std::convertible_to>; - { t.src_vector_tensor_lengths } -> std::convertible_to>; - { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; - { t.dst_vector_tensor_lengths } -> std::convertible_to>; + { t.thread_slice_lengths } -> std::convertible_to>; + { t.thread_cluster_lengths } -> std::convertible_to>; + { t.thread_cluster_arrange_order } -> std::convertible_to>; + { t.src_access_order } -> std::convertible_to>; + { t.src_vector_tensor_lengths } -> std::convertible_to>; + { t.src_vector_tensor_contiguous_dim_order } -> std::convertible_to>; + { t.dst_vector_tensor_lengths } -> std::convertible_to>; }; +template +concept DlBlockTransferDescriptor4D = DlBlockTransferDescriptor; + +template +concept DlBlockTransferDescriptor5D = DlBlockTransferDescriptor; + // Concept for DL epilogue template concept DlEpilogueDescriptor = requires(T t) { { t.src_dst_access_order } -> std::convertible_to>; - { t.src_dst_vector_dim } -> std::convertible_to; - { t.dst_scalar_per_vector } -> std::convertible_to; + { t.src_dst_vector_dim } -> SizeType; + { t.dst_scalar_per_vector } -> SizeType; }; // Concept to check if algorithm specifies DL thread config @@ -328,15 +426,21 @@ concept SpecifiesDlThreadCluster = requires { // Concept to check if algorithm specifies DL block transfer template -concept SpecifiesDlBlockTransfer = requires { - { T::transfer.a.block_transfer } -> DlBlockTransferDescriptor; - { T::transfer.b.block_transfer } -> DlBlockTransferDescriptor; +concept SpecifiesDlFwdBlockTransfer = requires { + { T::transfer.a } -> DlBlockTransferDescriptor4D; + { T::transfer.b } -> DlBlockTransferDescriptor4D; +}; + +template +concept SpecifiesDlBwdBlockTransfer = requires { + { T::transfer.a } -> DlBlockTransferDescriptor5D; + { T::transfer.b } -> DlBlockTransferDescriptor5D; }; // Concept to check if algorithm specifies DL C thread transfer template concept SpecifiesDlEpilogue = requires { - { T::transfer.c.epilogue } -> DlEpilogueDescriptor; + { T::transfer.c } -> DlEpilogueDescriptor; }; } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp index 10a619024a..d35897fc78 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp @@ -29,10 +29,20 @@ concept OutputVectorTransferLimits = requires { // Limits for access order. Must be a permutation of {0, 1, 2}. template -concept AccessOrderLimits = requires { +concept AccessOrderLimits3D = requires { requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) && (Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) && - (Value[2] >= 0 && Value[2] < 3)); + (Value[2] >= 0 && Value[2] < 3) && (Value.Size() == 3)); +}; + +// Limits for access order. Must be a permutation of {0, 1, 2, 3}. +template +concept AccessOrderLimits4D = requires { + requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[0] != Value[3]) && + (Value[1] != Value[2]) && (Value[1] != Value[3]) && (Value[2] != Value[3]) && + (Value[0] >= 0 && Value[0] < 4) && (Value[1] >= 0 && Value[1] < 4) && + (Value[2] >= 0 && Value[2] < 4) && (Value[3] >= 0 && Value[3] < 4) && + (Value.Size() == 4)); }; } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 39e081ec8d..c9cb6fe767 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -80,6 +80,7 @@ concept ConvOutputLayout3D = (L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) || (L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided); +namespace detail { template concept HasDataType = requires(T t) { { t.data_type }; @@ -94,10 +95,11 @@ concept DataTypeWellDefinedIfProvided = requires(T t) { }; }; +} // namespace detail template concept TensorConfigDescriptor = requires(T t) { { t.layout } -> std::convertible_to; - requires DataTypeWellDefinedIfProvided; + requires detail::DataTypeWellDefinedIfProvided; }; template @@ -116,7 +118,6 @@ template struct IsArrayOfTensorConfigDescriptors> : std::true_type { }; -} // namespace detail template concept ConvertibleToArrayOfTensorConfigs = @@ -128,11 +129,12 @@ concept AuxiliaryOperandConfigsWellDefinedIfProvided = requires(T t) { { t.auxiliary_operand_configs } -> ConvertibleToArrayOfTensorConfigs; }; }; +} // namespace detail template concept TensorOperatorDescriptor = requires(T t) { { t.elementwise_operation } -> std::convertible_to; - requires AuxiliaryOperandConfigsWellDefinedIfProvided; + requires detail::AuxiliaryOperandConfigsWellDefinedIfProvided; }; template @@ -140,6 +142,8 @@ concept HasTensorOp = requires(T t) { { t.operation }; }; +namespace detail { + template concept HasConvolutionDirection = requires(T t) { { t.direction }; @@ -159,11 +163,13 @@ concept ConvolutionDirectionWellDefinedIfProvided = requires(T t) { }; }; +} // namespace detail + // Concept for the convolution tensor template concept ConvTensorDescriptor = requires(T t) { { t.config } -> TensorConfigDescriptor; - requires ElementwiseOpWellDefinedIfProvided; + requires detail::ElementwiseOpWellDefinedIfProvided; }; template @@ -179,8 +185,9 @@ concept ConvSignatureDescriptor = requires(T t) { { t.input } -> ConvTensorDescriptor; { t.weight } -> ConvTensorDescriptor; { t.output } -> ConvTensorDescriptor; - requires ConvolutionDirectionWellDefinedIfProvided; - requires DataTypeWellDefinedIfProvided; + requires detail::ConvolutionDirectionWellDefinedIfProvided; + requires detail::DataTypeWellDefinedIfProvided; + requires detail::ElementwiseOpWellDefinedIfProvided; }; // Concept to validate a convolution signature's values. @@ -221,4 +228,13 @@ concept ValidConvWeightLayoutForSpatialDim = (SpatialDim == 1 && ConvWeightLayout1D) || (SpatialDim == 2 && ConvWeightLayout2D) || (SpatialDim == 3 && ConvWeightLayout3D); +// Constraint for 3D conv signature. +template +concept Is3D = requires { + requires Sig.spatial_dim == 3; + requires ConvInputLayout3D; + requires ConvOutputLayout3D; + requires ConvWeightLayout3D; +}; + } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp new file mode 100644 index 0000000000..fc0ee48ec0 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -0,0 +1,128 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" + +namespace ck_tile::builder::factory { + +// Base algorithm concepts +template +concept TileTransferParameters = + SpecifiesBlockTransfer && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder; + +template +concept SpecifiesTileTransferParameters3D = TileTransferParameters; + +template +concept SpecifiesTileTransferParameters4D = TileTransferParameters; + +template +concept FwdXdlAlgorithmBase = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseFwdXdlGemm && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && + SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; + +template +concept BwdXdlAlgorithmBase = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters4D && + SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization; + +template +concept BwdXdlV3AlgorithmBase = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization && + SpecifiesBlockGemm; + +template +concept BwdWmmaAlgorithmBase = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseWmmaGemm && SpecifiesBwdWeightConvSpecialization; + +template +concept BwdWmmaV3AlgorithmBase = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseWmmaGemm && SpecifiesBwdWeightConvSpecialization && + SpecifiesBlockGemm; + +// Reference algorithm concept +template +concept ReferenceAlgorithm = ConvAlgorithmDescriptor && SpecifiesReferenceAlgorithm; + +// Tile-based algorithm concept +template +concept TileAlgorithm = ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && + SpecifiesTileTransfer && SpecifiesTileConvSpecialization && + SpecifiesTileBlockGemm && SpecifiesTileOptimizations; + +// FWD XDL algorithm concepts +template +concept FwdXdlAlgorithm = FwdXdlAlgorithmBase && SpecifiesGenericInstance; + +template +concept LargeTensorAlgorithm = FwdXdlAlgorithmBase && SpecifiesLargeTensorSupport; + +template +concept FwdXdlV3Algorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseFwdXdlGemm && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesBlockGemm; + +// FWD WMMA algorithm concepts +template +concept FwdWmmaAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && + SpecifiesGridwiseWmmaGemm && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && + SpecifiesGridwiseGemmPipeline; + +// FWD DL algorithms +template +concept FwdDlAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && + SpecifiesDlFwdBlockTransfer && SpecifiesDlEpilogue; + +// BWD weight XDL algorithm concepts +template +concept BwdXdlAlgorithm = + BwdXdlAlgorithmBase && SpecifiesTransposeTransfer && SpecifiesGenericInstance; + +template +concept BwdMultiDXdlAlgorithm = BwdXdlAlgorithmBase && SpecifiesMultipleDSupport; + +template +concept BwdXdlV3Algorithm = BwdXdlV3AlgorithmBase && SpecifiesGenericInstance; + +template +concept BwdTwoStageXdlAlgorithm = BwdXdlV3AlgorithmBase && SpecifiesTransposeTransfer && + SpecifiesGemmBatchOptions && SpecifiesTwoStageSupport; + +// BWD weight WMMA algorithm concepts +template +concept BwdWmmaAlgorithm = + BwdWmmaAlgorithmBase && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler && + SpecifiesGridwiseGemmPipeline && SpecifiesGenericInstance; + +template +concept BwdMultiDWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesMultipleDSupport; + +template +concept BwdWmmaV3Algorithm = + BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && SpecifiesGenericInstance; + +template +concept BwdTwoStageWmmaV3Algorithm = BwdWmmaV3AlgorithmBase && SpecifiesTransposeTransfer && + SpecifiesGemmBatchOptions && SpecifiesTwoStageSupport; + +// BWD weigth DL algorithms +template +concept BwdDlAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesBwdWeightConvSpecialization && SpecifiesDlThreadConfig && + SpecifiesDlThreadCluster && SpecifiesDlBwdBlockTransfer && SpecifiesDlEpilogue; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp new file mode 100644 index 0000000000..fda1659c75 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp @@ -0,0 +1,131 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeight_Dl instance +// of a grouped bwd weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightDlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + + // DL-specific parameters from algorithm descriptor + static constexpr auto DL_THREAD_CFG = ALGORITHM.thread_config; + static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block; + static constexpr ck::index_t K1 = DL_THREAD_CFG.k1; + static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread; + static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread; + static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread; + + // Thread cluster from descriptor + static constexpr auto DL_CLUSTER = ALGORITHM.thread_cluster; + using M1N1ThreadClusterM1Xs = to_sequence_v; + using M1N1ThreadClusterN1Xs = to_sequence_v; + + // A Block Transfer from descriptor - K0_M0_M1_K1 tensor format + static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a; + using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferThreadClusterArrangeOrder = + to_sequence_v; + using ABlockTransferSrcAccessOrder = to_sequence_v; + using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferSrcVectorTensorContiguousDimOrder = + to_sequence_v; + using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = + to_sequence_v; + + // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format + static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b; + using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferThreadClusterArrangeOrder = + to_sequence_v; + using BBlockTransferSrcAccessOrder = to_sequence_v; + using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferSrcVectorTensorContiguousDimOrder = + to_sequence_v; + using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = + to_sequence_v; + + // C Thread Transfer from descriptor + static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c; + using CThreadTransferSrcDstAccessOrder = to_sequence_v; + static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim; + static constexpr ck::index_t CThreadTransferDstScalarPerVector = + DL_C_TRANSFER.dst_scalar_per_vector; + + // The DL forward convolution kernel class instance + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + K0PerBlock, + K1, + M1PerThread, + N1PerThread, + KPerThread, + M1N1ThreadClusterM1Xs, + M1N1ThreadClusterN1Xs, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp new file mode 100644 index 0000000000..b02dea9558 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp @@ -0,0 +1,110 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 instance +// of a grouped bwd weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight && Is3D +struct ConvBwdWeightMultiDWmmaV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); + + // The forward convolution kernel class instance. + using Instance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Layouts::DsLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Types::DsDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + typename Types::OutComputeType, + typename Types::InComputeType>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp new file mode 100644 index 0000000000..4f6812617a --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp @@ -0,0 +1,103 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance +// of a grouped bwd weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightMultiDXdlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Layouts::DsLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Types::DsDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + typename Types::OutComputeType, + typename Types::InComputeType>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp new file mode 100644 index 0000000000..adf108bac4 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp @@ -0,0 +1,111 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffle_V3 instance +// of a grouped bwd weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightTwoStageWmmaV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); + + // The forward convolution kernel class instance. + using Instance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + ALGORITHM.num_conv_groups_to_merge, + typename Types::OutComputeType, + typename Types::InComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp new file mode 100644 index 0000000000..d887c1c1ce --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp @@ -0,0 +1,111 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle instance +// of a grouped bwd weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightTwoStageXdlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + ALGORITHM.num_conv_groups_to_merge, + typename Types::OutComputeType, + typename Types::InComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp new file mode 100644 index 0000000000..4067845291 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp @@ -0,0 +1,109 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeight_Wmma_CShuffle instance +// of a grouped bwd weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight && Is3D +struct ConvBwdWeightWmmaFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = + internal::SetGridwiseGemmPipelineVersion(); + static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); + + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + ALGORITHM.num_gemm_k_prefetch_stages, + LOOP_SCHEDULER, + GRIDWISE_GEMM_PIPELINE_VERSION>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp new file mode 100644 index 0000000000..027c8a1fba --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp @@ -0,0 +1,109 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 instance +// of a grouped bwd weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightWmmaV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + typename Types::OutComputeType, + typename Types::InComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp new file mode 100644 index 0000000000..fbb177f333 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -0,0 +1,103 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeight_Xdl_CShuffle instance +// of a grouped bwd weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightXdlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + static_assert(AccessOrderLimits4D); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + typename Types::OutComputeType, + typename Types::InComputeType, + ALGORITHM.max_transpose_transfer_src_scalar_per_vector, + ALGORITHM.max_transpose_transfer_dst_scalar_per_vector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp new file mode 100644 index 0000000000..66a47c5407 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp @@ -0,0 +1,108 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 instance +// of a grouped bwd weight convolution kernel. +template + requires ConvDirectionIsBackwardWeight +struct ConvBwdWeightXdlV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BWD_CONV_SPECIALIZATION = + internal::SetBwdWeightConvSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetBwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); + static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); + static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); + static_assert(AccessOrderLimits3D, + "Invalid A thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid B thread cluster access order"); + static_assert(AccessOrderLimits3D, + "Invalid A source access order"); + static_assert(AccessOrderLimits3D, + "Invalid B source access order"); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< + SPATIAL_DIM, + typename Layouts::InLayout, + typename Layouts::WeiLayout, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::OutDataType, + typename Types::AccDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, + BWD_CONV_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + A_BLOCK_TRANSFER.lds_padding, + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + B_BLOCK_TRANSFER.lds_padding, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + typename Types::OutComputeType, + typename Types::InComputeType>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index c0dd3d8018..e235db4bb0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -57,6 +57,9 @@ #include "ck_tile/builder/conv_algorithm_concepts.hpp" #include "ck_tile/builder/types.hpp" +// Compile time diagnostics +#include "ck_tile/builder/factory/conv_algorithms.hpp" + // Include all factory implementations #include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp" #include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp" @@ -65,6 +68,15 @@ #include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" #include "ck_tile/builder/factory/reference_factory.hpp" #include "ck_tile/builder/factory/conv_tile_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_dl_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp" +#include "ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp" namespace ck_tile::builder::factory { @@ -87,56 +99,6 @@ namespace ck_tile::builder::factory { // // TODO: Make this dispatch logic much more robust and clear for users. -// Reference algorithm (simplest implementation for validation) -template -concept IsReferenceAlgorithm = ConvAlgorithmDescriptor && requires { - { T::specialization } -> std::convertible_to; - requires T::specialization == ConvAlgorithmSpecialization::REFERENCE; -}; - -// CK Tile kernel -template -concept IsTileAlgorithm = ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && - SpecifiesTileTransfer && SpecifiesTileConvSpecialization && - SpecifiesTileBlockGemm && SpecifiesTileOptimizations; - -// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) -template -concept IsXdlV3Algorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesBlockGemm; - -// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply) -template -concept IsXdlAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && - SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; - -// WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions) -template -concept IsWmmaAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; - -// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts -template -concept IsDlAlgorithm = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConvSpecialization && - SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && - SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; - -// XDL-based kernel with large tensor support -template -concept IsLargeTensorAlgorithm = - IsXdlAlgorithm && SpecifiesLargeTensorSupport; - template @@ -145,35 +107,35 @@ constexpr auto make_conv_instance() using AlgoType = std::remove_const_t; // Reference algorithm supports all directions - if constexpr(IsReferenceAlgorithm) + if constexpr(ReferenceAlgorithm) { return typename ReferenceFactory::Instance{}; } // CK Tile supports common factory for each direction - else if constexpr(IsTileAlgorithm) + else if constexpr(TileAlgorithm) { return typename ConvTileFactory::Instance{}; } // Forward direction (supports most algorithm variants) else if constexpr(ConvDirectionIsForward) { - if constexpr(IsXdlV3Algorithm) + if constexpr(FwdXdlV3Algorithm) { return typename ConvFwdXdlV3Factory::Instance{}; } - else if constexpr(IsXdlAlgorithm) + else if constexpr(FwdXdlAlgorithm) { return typename ConvFwdXdlFactory::Instance{}; } - else if constexpr(IsWmmaAlgorithm) + else if constexpr(FwdWmmaAlgorithm) { return typename ConvFwdWmmaFactory::Instance{}; } - else if constexpr(IsDlAlgorithm) + else if constexpr(FwdDlAlgorithm) { return typename ConvFwdDlFactory::Instance{}; } - else if constexpr(IsLargeTensorAlgorithm) + else if constexpr(LargeTensorAlgorithm) { return typename ConvFwdLargeTensorFactory::Instance{}; } @@ -197,10 +159,55 @@ constexpr auto make_conv_instance() // Backward weight direction (will expand with more algorithms in the future) else if constexpr(ConvDirectionIsBackwardWeight) { - static_assert(false, - "Backward weight convolution: Only reference and tile algorithms " - "supported currently. " - "Optimized kernels (XDL, WMMA, etc.) not yet implemented."); + if constexpr(BwdXdlAlgorithm) + { + return typename ConvBwdWeightXdlFactory::Instance{}; + } + else if constexpr(BwdXdlV3Algorithm) + { + return typename ConvBwdWeightXdlV3Factory::Instance{}; + } + else if constexpr(BwdTwoStageXdlAlgorithm) + { + return + typename ConvBwdWeightTwoStageXdlFactory::Instance{}; + } + else if constexpr(BwdDlAlgorithm) + { + return typename ConvBwdWeightDlFactory::Instance{}; + } + else if constexpr(BwdMultiDXdlAlgorithm) + { + return + typename ConvBwdWeightMultiDXdlFactory::Instance{}; + } + else if constexpr(BwdWmmaV3Algorithm) + { + return typename ConvBwdWeightWmmaV3Factory::Instance{}; + } + else if constexpr(BwdTwoStageWmmaV3Algorithm) + { + return typename ConvBwdWeightTwoStageWmmaV3Factory:: + Instance{}; + } + else if constexpr(BwdWmmaAlgorithm) + { + return typename ConvBwdWeightWmmaFactory::Instance{}; + } + else if constexpr(BwdMultiDWmmaV3Algorithm) + { + return typename ConvBwdWeightMultiDWmmaV3Factory:: + Instance{}; + } + else + { + static_assert( + false, + "No suitable backward weight convolution kernel factory found for the provided " + "ALGORITHM. The ALGORITHM must satisfy requirements for one of: Reference, Tile, " + "XDL, XDL V3, Two-Stage XDL, DL, Multi-D XDL, WMMA V3, Two-Stage " + "WMMA V3, WMMA, or Multi-D WMMA V3 variant."); + } } else { diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp index ca202aabfd..1d55772dd6 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -24,10 +24,10 @@ template ; - using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); @@ -48,7 +48,7 @@ struct ConvFwdDlFactory using M1N1ThreadClusterN1Xs = to_sequence_v; // A Block Transfer from descriptor - K0_M0_M1_K1 tensor format - static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a.block_transfer; + static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a; using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = to_sequence_v; using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = @@ -64,7 +64,7 @@ struct ConvFwdDlFactory to_sequence_v; // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format - static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b.block_transfer; + static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b; using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = to_sequence_v; using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = @@ -80,7 +80,7 @@ struct ConvFwdDlFactory to_sequence_v; // C Thread Transfer from descriptor - static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c.epilogue; + static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c; using CThreadTransferSrcDstAccessOrder = to_sequence_v; static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim; static constexpr ck::index_t CThreadTransferDstScalarPerVector = @@ -89,18 +89,18 @@ struct ConvFwdDlFactory // The DL forward convolution kernel class instance using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< SPATIAL_DIM, - typename Types::ADataType, - typename Types::BDataType, - typename Types::DsDataTypes, - typename Types::EDataType, + typename Types::InDataType, + typename Types::WeiDataType, + typename Types::DsDataType, + typename Types::OutDataType, typename Types::AccDataType, - typename Layouts::ALayout, - typename Layouts::BLayout, + typename Layouts::InLayout, + typename Layouts::WeiLayout, typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, + typename Layouts::OutLayout, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, FWD_CONV_SPECIALIZATION, GEMM_SPECIALIZATION, BLOCK.block_size, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index fadf41f48a..0ff410d731 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -26,68 +26,65 @@ template ; - using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); - static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm; - - static constexpr auto FWD_CONV_SPECIALIZATION = - internal::SetFwdConvSpecialization(); - static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); + static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, .gemm_spec = GEMM_SPECIALIZATION}; - static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); - static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm; + static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = - internal::SetFwdConvBlockTransfer(); + internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = - internal::SetFwdConvBlockTransfer(); - static constexpr auto C_BLOCK_TRANSFER = - internal::SetCBlockTransfer(); + internal::SetFwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); // Check limits for the algorithm parameters. static_assert(InputVectorTransferLimits); static_assert(InputVectorTransferLimits); static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); // The forward convolution kernel class instance with large tensor support. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, + typename Layouts::InLayout, + typename Layouts::WeiLayout, typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Types::ADataType, - typename Types::BDataType, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::OutDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, - BASE_ALGORITHM.num_gemm_k_prefetch_stages, + ALGORITHM.num_gemm_k_prefetch_stages, BLOCK.block_size, BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, GRIDWISE_GEMM.ak1, GRIDWISE_GEMM.bk1, - GRIDWISE_GEMM.m_per_xdl, - GRIDWISE_GEMM.n_per_xdl, - GRIDWISE_GEMM.m_xdl_per_wave, - GRIDWISE_GEMM.n_xdl_per_wave, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, @@ -106,8 +103,8 @@ struct ConvFwdLargeTensorFactory C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, to_sequence_v, C_BLOCK_TRANSFER.scalar_per_vector, - typename Types::AComputeType, - typename Types::BComputeType, + typename Types::InComputeType, + typename Types::WeiComputeType, LOOP_SCHEDULER>; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 89787cc1b3..dd2fa65eae 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -26,10 +26,10 @@ template ; - using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load == ALGORITHM.transfer.b.lds_transfer.is_direct_load, @@ -43,6 +43,7 @@ struct ConvFwdXdlV3Factory static constexpr auto BLOCK = internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -55,27 +56,27 @@ struct ConvFwdXdlV3Factory static_assert(InputVectorTransferLimits); static_assert(InputVectorTransferLimits); static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, + typename Layouts::InLayout, + typename Layouts::WeiLayout, typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Types::ADataType, - typename Types::BDataType, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::OutDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, BLOCK.block_size, @@ -84,10 +85,10 @@ struct ConvFwdXdlV3Factory BLOCK.per_block.k, GRIDWISE_GEMM.ak1, GRIDWISE_GEMM.bk1, - GRIDWISE_GEMM.m_per_xdl, - GRIDWISE_GEMM.n_per_xdl, - GRIDWISE_GEMM.m_xdl_per_wave, - GRIDWISE_GEMM.n_xdl_per_wave, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, @@ -108,8 +109,8 @@ struct ConvFwdXdlV3Factory C_BLOCK_TRANSFER.scalar_per_vector, BLOCK_GEMM.scheduler, BLOCK_GEMM.pipeline_version, - typename Types::AComputeType, - typename Types::BComputeType, + typename Types::InComputeType, + typename Types::WeiComputeType, IS_DIRECT_LOAD>; }; diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index bb84479071..2d6f7c394b 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -26,10 +26,10 @@ template ; - using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); @@ -52,27 +52,27 @@ struct ConvFwdWmmaFactory static_assert(InputVectorTransferLimits); static_assert(InputVectorTransferLimits); static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, + typename Layouts::InLayout, + typename Layouts::WeiLayout, typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Types::ADataType, - typename Types::BDataType, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::OutDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, ALGORITHM.num_gemm_k_prefetch_stages, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index 8ec5c633ce..e03e035969 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -26,10 +26,10 @@ template ; - using Types = internal::FwdConvTensorDataTypes; - using Ops = internal::ElementwiseOps; - using AlgorithmType = decltype(ALGORITHM); + using Layouts = internal::ConvTensorLayouts; + using Types = internal::ConvTensorDataTypes; + using Ops = internal::ConvElementwiseOps; + using AlgorithmType = decltype(ALGORITHM); static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); @@ -39,6 +39,7 @@ struct ConvFwdXdlFactory static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -50,27 +51,27 @@ struct ConvFwdXdlFactory static_assert(InputVectorTransferLimits); static_assert(InputVectorTransferLimits); static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); + static_assert(AccessOrderLimits3D); // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, + typename Layouts::InLayout, + typename Layouts::WeiLayout, typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Types::ADataType, - typename Types::BDataType, + typename Layouts::OutLayout, + typename Types::InDataType, + typename Types::WeiDataType, typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, + typename Types::OutComputeType, + typename Types::DsDataType, + typename Types::OutDataType, + typename Ops::InElementwiseOp, + typename Ops::WeiElementwiseOp, + typename Ops::OutElementwiseOp, SPECIALIZATION.conv_spec, SPECIALIZATION.gemm_spec, ALGORITHM.num_gemm_k_prefetch_stages, @@ -80,10 +81,10 @@ struct ConvFwdXdlFactory BLOCK.per_block.k, GRIDWISE_GEMM.ak1, GRIDWISE_GEMM.bk1, - GRIDWISE_GEMM.m_per_xdl, - GRIDWISE_GEMM.n_per_xdl, - GRIDWISE_GEMM.m_xdl_per_wave, - GRIDWISE_GEMM.n_xdl_per_wave, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, @@ -102,10 +103,10 @@ struct ConvFwdXdlFactory C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, to_sequence_v, C_BLOCK_TRANSFER.scalar_per_vector, - typename Types::AComputeType, - typename Types::BComputeType, + typename Types::InComputeType, + typename Types::WeiComputeType, LOOP_SCHEDULER, - ALGORITHM.num_groups_to_merge>; + ALGORITHM.num_conv_groups_to_merge>; }; } // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 5da1e4eadb..d873a4b903 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -10,27 +10,28 @@ namespace ck_tile::builder::factory::internal { // Block transfer parameters for A or B tensor. +template struct BlockTransfer { - ck::Array thread_cluster_dims = {0, 0, 0}; // k0, m, k1 - ck::Array thread_cluster_order = {0, 0, 0}; - ck::Array src_access_order = {0, 0, 0}; - size_t src_vector_dim = 0; - size_t src_scalar_per_vector = 0; - size_t lds_dst_scalar_per_vector = 0; - bool is_direct_load = false; - bool lds_padding = false; + ck::Array thread_cluster_dims{}; + ck::Array thread_cluster_order{}; + ck::Array src_access_order{}; + size_t src_vector_dim = 0; + size_t src_scalar_per_vector = 0; + size_t lds_dst_scalar_per_vector = 0; + bool is_direct_load = false; + bool lds_padding = false; }; template -constexpr BlockTransfer SetFwdConvBlockTransfer() +constexpr BlockTransfer<> SetFwdConvBlockTransfer() { auto& block_xfer = TRANSFER.block_transfer; auto& block_order = TRANSFER.block_transfer_access_order; auto& src_order = TRANSFER.src_access_order; auto& lds_cfg = TRANSFER.lds_transfer; - return BlockTransfer{ + return BlockTransfer<>{ .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, @@ -42,6 +43,59 @@ constexpr BlockTransfer SetFwdConvBlockTransfer() }; } +template +constexpr auto SetBwdConvBlockTransfer() +{ + auto& block_xfer = TRANSFER.block_transfer; + auto& block_order = TRANSFER.block_transfer_access_order; + auto& src_order = TRANSFER.src_access_order; + auto& lds_cfg = TRANSFER.lds_transfer; + + constexpr auto array_length = block_order.order.size(); + static_assert(block_order.order.size() == src_order.order.size(), + "Mismatched size between block order and src order"); + + if constexpr(array_length == 3) + { + return BlockTransfer<3>{ + .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], + block_order.order[1], + block_order.order[2]}, + .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .lds_padding = lds_cfg.lds_padding, + }; + } + else if constexpr(array_length == 4) + { + return BlockTransfer<4>{ + .thread_cluster_dims = {block_xfer.k_batch_size, + block_xfer.k0, + block_xfer.m_n, + block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], + block_order.order[1], + block_order.order[2], + block_order.order[3]}, + .src_access_order = {src_order.order[0], + src_order.order[1], + src_order.order[2], + src_order.order[3]}, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .lds_padding = lds_cfg.lds_padding, + }; + } + else + { + static_assert(false, "Internal error: Unsupported array length"); + } +} + // Block transfer parameters for C tensor. struct CBlockTransfer { diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp index a39cd7410b..0cc43fc679 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp @@ -62,14 +62,15 @@ consteval auto GetElementwiseOp() } template -struct ElementwiseOps +struct ConvElementwiseOps { static constexpr auto input_op = GetElementwiseOp(); static constexpr auto weight_op = GetElementwiseOp(); static constexpr auto output_op = GetElementwiseOp(); - using AElementwiseOp = typename decltype(input_op)::Op; - using BElementwiseOp = typename decltype(weight_op)::Op; - using CDEElementwiseOp = typename decltype(output_op)::Op; + + using InElementwiseOp = typename decltype(input_op)::Op; + using WeiElementwiseOp = typename decltype(weight_op)::Op; + using OutElementwiseOp = typename decltype(output_op)::Op; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp index a6c0b48c54..fd6de9ae21 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp @@ -190,7 +190,7 @@ consteval auto GetAuxiliaryTensorLayoutTuple(std::index_sequence) decltype(TensorLayoutToCK())...>{}; } -template +template requires(ConvSpatialDim) struct AuxiliaryTensorLayouts { @@ -200,34 +200,32 @@ struct AuxiliaryTensorLayouts }; // TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias). -template +template requires(HasElementwiseOpWithAuxiliaryOperands) consteval auto GetAuxiliaryTensorLayouts() { return AuxiliaryTensorLayouts{}; + SPATIAL_DIM>{}; } -template +template requires(!HasElementwiseOpWithAuxiliaryOperands) consteval auto GetAuxiliaryTensorLayouts() { return EmptyAuxiliaryTensorLayout{}; } -template +template requires(ConvSpatialDim && ValidConvInputLayoutForSpatialDim && ValidConvWeightLayoutForSpatialDim && ValidConvOutputLayoutForSpatialDim) struct ConvTensorLayouts { - static_assert(DIR == ConvDirection::FORWARD, "Only Forward convolution is supported."); - using ALayout = decltype(TensorLayoutToCK()); - using BLayout = decltype(TensorLayoutToCK()); - using ELayout = decltype(TensorLayoutToCK()); - using DsLayout = decltype(GetAuxiliaryTensorLayouts())::type; + using InLayout = decltype(TensorLayoutToCK()); + using WeiLayout = decltype(TensorLayoutToCK()); + using OutLayout = decltype(TensorLayoutToCK()); + using DsLayout = decltype(GetAuxiliaryTensorLayouts())::type; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp index 9430573cc6..0c017e0c47 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp @@ -33,7 +33,7 @@ struct DataTypeToCK using type = float; }; template <> -struct DataTypeToCK +struct DataTypeToCK { using type = int32_t; }; @@ -156,7 +156,7 @@ consteval auto GetAuxiliaryTensorDataTypes() } template -struct FwdConvTensorDataTypes +struct ConvTensorDataTypes { static constexpr auto input_types = GetTensorDataAndComputeTypes(); @@ -165,20 +165,17 @@ struct FwdConvTensorDataTypes static constexpr auto output_types = GetTensorDataAndComputeTypes(); - using ADataType = typename decltype(input_types.first)::type; - using AComputeType = typename decltype(input_types.second)::type; - using BDataType = typename decltype(weight_types.first)::type; - using BComputeType = typename decltype(weight_types.second)::type; + using InDataType = typename decltype(input_types.first)::type; + using InComputeType = typename decltype(input_types.second)::type; + using WeiDataType = typename decltype(weight_types.first)::type; + using WeiComputeType = typename decltype(weight_types.second)::type; + using OutDataType = typename decltype(output_types.first)::type; + using OutComputeType = typename decltype(output_types.second)::type; using AccDataType = typename decltype(GetTensorAccumulationType())::type; - using EDataType = typename decltype(output_types.first)::type; - - // This is the "compute" type for output. - using CShuffleDataType = typename decltype(output_types.second)::type; - // Data types for the auxiliary tensors (e.g., bias). - using DsDataTypes = typename decltype(GetAuxiliaryTensorDataTypes())::type; + using DsDataType = typename decltype(GetAuxiliaryTensorDataTypes())::type; }; } // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index db741f2112..9ed1eebc3c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" @@ -37,7 +38,7 @@ struct BlockGemmSpec template consteval BlockGemmSpec SetBlockGemm() { - constexpr auto& BG = ALGORITHM.block_gemm; + constexpr auto& BG = ALGORITHM.block_gemm_pipeline; ck::BlockGemmPipelineScheduler scheduler; ck::BlockGemmPipelineVersion version; @@ -82,7 +83,7 @@ consteval ck::LoopScheduler SetLoopScheduler() template consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() { - constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version; + constexpr auto pipeline_version = ALGORITHM.pipeline_version; using ck_pipeline = ck::PipelineVersion; switch(pipeline_version) { @@ -149,12 +150,30 @@ consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdC using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization; switch(specialization) { - case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default; - case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; - case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; - case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3; - case ConvFwdSpecialization::ODD_C: return ck_conv_spec::OddC; - default: throw "Unknown ConvFwdSpecialization"; + case ConvSpecialization::DEFAULT: return ck_conv_spec::Default; + case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; + case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; + case ConvSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3; + case ConvSpecialization::ODD_C: return ck_conv_spec::OddC; + default: throw "Unsupported ConvSpecialization"; + } +} + +template +consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization +SetBwdWeightConvSpecialization() +{ + constexpr auto specialization = ALGORITHM.bwd_weight_specialization; + using ck_conv_spec = ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + switch(specialization) + { + case ConvSpecialization::DEFAULT: return ck_conv_spec::Default; + case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; + case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; + case ConvSpecialization::ODD_C: return ck_conv_spec::OddC; + case ConvSpecialization::FILTER_3x3: + throw "FILTER_3x3 is not supported for backward weight convolution."; + default: throw "Unsupported ConvSpecialization"; } } diff --git a/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp index 0748725c96..f6fc2dbda8 100644 --- a/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp @@ -26,11 +26,11 @@ struct ReferenceFactory static constexpr auto kValidation = (internal::ValidateReferenceSignature(), 0); static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - using Types = internal::FwdConvTensorDataTypes; + using Types = internal::ConvTensorDataTypes; - using InDataType = typename Types::ADataType; - using WeiDataType = typename Types::BDataType; - using OutDataType = typename Types::EDataType; + using InDataType = typename Types::InDataType; + using WeiDataType = typename Types::WeiDataType; + using OutDataType = typename Types::OutDataType; struct Instance { diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index 46c9bb488e..a7b6c60a73 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -63,10 +63,7 @@ struct GemmAlgorithmInfo OutputTileTransferInfo c_tile_transfer; builder::PipelineVersion pipeline_version; builder::PipelineScheduler pipeline_scheduler; - std::variant - conv_specialization; + builder::ConvSpecialization conv_specialization; builder::GemmPadding padding; }; diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index a91abd1a46..8caa11618e 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -197,18 +197,16 @@ constexpr builder::ConvDirection conv_direction() /// @brief Derives the convolution-specific specialization from a device kernel `Instance` type. /// @tparam Instance The device kernel instance type. -/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or -/// `builder::ConvBwdWeightSpecialization` enum value. +/// @return A `builder::ConvSpecialization` enum value. template constexpr auto conv_spec() { using InstTraits = InstanceTraits; + using enum builder::ConvSpecialization; if constexpr(requires { InstTraits::kConvForwardSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; - using enum builder::ConvFwdSpecialization; - switch(InstTraits::kConvForwardSpecialization) { case Default: return DEFAULT; @@ -221,8 +219,6 @@ constexpr auto conv_spec() else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; - using enum builder::ConvBwdDataSpecialization; - switch(InstTraits::kConvBwdDataSpecialization) { case Default: return DEFAULT; @@ -232,8 +228,6 @@ constexpr auto conv_spec() else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) { using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; - using enum builder::ConvBwdWeightSpecialization; - switch(InstTraits::kConvBwdWeightSpecialization) { case Default: return DEFAULT; diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_reference.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_reference.hpp index b2e8bb6a7c..6875e586cd 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_reference.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_reference.hpp @@ -35,10 +35,10 @@ struct ReferenceCommonTraits typename builder::factory::internal::LayoutToCK::type; // Data types - extract from factory's type helper - using Types = builder::factory::internal::FwdConvTensorDataTypes; - using ADataType = typename Types::ADataType; - using BDataType = typename Types::BDataType; - using EDataType = typename Types::EDataType; + using Types = builder::factory::internal::ConvTensorDataTypes; + using ADataType = typename Types::InDataType; + using BDataType = typename Types::WeiDataType; + using EDataType = typename Types::OutDataType; using AccDataType = float; // Reference uses float accumulation // Elementwise operations - reference only supports PassThrough diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp index 8cbafa7efa..d8910152dd 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp @@ -7,6 +7,7 @@ #include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" #include "ck_tile/builder/testing/testing.hpp" +#include "ck_tile/builder/testing/testing_reflect.hpp" #include "ck_tile/builder/testing/filter_extent.hpp" #include "ck_tile/builder/testing/tensor_buffer.hpp" #include "ck_tile/builder/testing/tensor_initialization.hpp" @@ -71,11 +72,10 @@ struct Args using OutputDescriptor = TensorDescriptor; // TODO: We shouldn't need to call into an internal namespace here. - using Ops = factory::internal::ElementwiseOps; + using Ops = factory::internal::ConvElementwiseOps; // TODO: We shouldn't need to call into an internal namespace here. - using Layouts = - factory::internal::ConvTensorLayouts; + using Layouts = factory::internal::ConvTensorLayouts; ConvTensorLengths lengths; @@ -89,9 +89,9 @@ struct Args FilterExtent input_left_pad; FilterExtent input_right_pad; - Ops::AElementwiseOp a_elementwise_op; - Ops::BElementwiseOp b_elementwise_op; - Ops::CDEElementwiseOp cde_elementwise_op; + Ops::InElementwiseOp a_elementwise_op; + Ops::WeiElementwiseOp b_elementwise_op; + Ops::OutElementwiseOp cde_elementwise_op; /// This function returns the `TensorDescriptor` corresponding to /// the input-tensor of the convolution problem. This can then @@ -106,7 +106,7 @@ struct Args // function. const auto param = to_ck_conv_param(); const auto desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed< - typename Layouts::ALayout>(param); + typename Layouts::InLayout>(param); using Extent = typename InputDescriptor::Extent; return InputDescriptor(Extent::from_vector(desc.GetLengths()), Extent::from_vector(desc.GetStrides())); @@ -120,7 +120,7 @@ struct Args // See note in implementation of `make_input_descriptor`. const auto param = to_ck_conv_param(); const auto desc = ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed< - typename Layouts::BLayout>(param); + typename Layouts::WeiLayout>(param); using Extent = typename WeightDescriptor::Extent; return WeightDescriptor(Extent::from_vector(desc.GetLengths()), Extent::from_vector(desc.GetStrides())); @@ -134,7 +134,7 @@ struct Args // See note in implementation of `make_input_descriptor`. const auto param = to_ck_conv_param(); const auto desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed< - typename Layouts::ELayout>(param); + typename Layouts::OutLayout>(param); using Extent = typename OutputDescriptor::Extent; return OutputDescriptor(Extent::from_vector(desc.GetLengths()), Extent::from_vector(desc.GetStrides())); @@ -182,6 +182,12 @@ struct Inputs { void* input; void* weight; + + static void reflect(const Args& args, const auto& inspect) + { + inspect("input", args.make_input_descriptor(), &Inputs::input); + inspect("weight", args.make_weight_descriptor(), &Inputs::weight); + } }; /// @brief `Outputs` specialization for forward convolution. @@ -194,68 +200,13 @@ template struct Outputs { void* output; -}; -/// @brief `UniqueInputs` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see UniqueInputs -/// @see ValidUniqueInputs -template - requires ValidConvSignature && ConvDirectionIsForward -struct UniqueInputs -{ - DeviceBuffer input_buf; - DeviceBuffer weight_buf; - - /// @see ValidUniqueInputs - Inputs get() + static void reflect(const Args& args, const auto& inspect) { - return { - .input = input_buf.get(), - .weight = weight_buf.get(), - }; + inspect("output", args.make_output_descriptor(), &Outputs::output); } }; -/// @brief `UniqueOutputs` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see UniqueOutputs -/// @see ValidUniqueOutputs -template - requires ValidConvSignature && ConvDirectionIsForward -struct UniqueOutputs -{ - DeviceBuffer output_buf; - - /// @see ValidUniqueOutputs - Outputs get() - { - return { - .output = output_buf.get(), - }; - } -}; - -/// @brief `alloc_inputs()` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see alloc_inputs() -template - requires ValidConvSignature && ConvDirectionIsForward && - ValidUniqueInputs -UniqueInputs alloc_inputs(const Args& args) -{ - return { - .input_buf = alloc_tensor_buffer(args.make_input_descriptor()), - .weight_buf = alloc_tensor_buffer(args.make_weight_descriptor()), - }; -} - /// @brief `init_inputs()` specialization for forward convolution. /// /// @tparam SIGNATURE Forward convolution signature. @@ -269,34 +220,4 @@ void init_inputs(const Args& args, Inputs inputs) init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f); } -/// @brief `alloc_outputs()` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see alloc_outputs() -template - requires ValidConvSignature && ConvDirectionIsForward && - ValidUniqueOutputs -UniqueOutputs alloc_outputs(const Args& args) -{ - return { - .output_buf = alloc_tensor_buffer(args.make_output_descriptor()), - }; -} - -/// @brief `validate()` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see validate() -template - requires ValidConvSignature && ConvDirectionIsForward -ValidationReport -validate(const Args& args, Outputs actual, Outputs expected) -{ - ValidationReport report; - report.check("output", args.make_output_descriptor(), actual.output, expected.output); - return report; -} - } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp index 499e0ef3de..a90f53ba7d 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp @@ -27,7 +27,7 @@ template > + typename Ops = factory::internal::ConvElementwiseOps> concept CkConvInstance = requires(Conv& conv, // TODO: This should be changed depending on IsMultiA etc. // Currently that is not yet supported elsewhere anyway. @@ -37,9 +37,9 @@ concept CkConvInstance = requires(Conv& conv, std::array lengths, std::array strides, std::array filter, - Ops::AElementwiseOp elementwise_a, - Ops::BElementwiseOp elementwise_b, - Ops::CDEElementwiseOp elementwise_cde) { + Ops::InElementwiseOp elementwise_a, + Ops::WeiElementwiseOp elementwise_b, + Ops::OutElementwiseOp elementwise_cde) { { conv.MakeArgument(p_a, p_b, diff --git a/experimental/builder/include/ck_tile/builder/testing/debug.hpp b/experimental/builder/include/ck_tile/builder/testing/debug.hpp new file mode 100644 index 0000000000..4014d62d48 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/debug.hpp @@ -0,0 +1,634 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/testing/error.hpp" +#include "ck_tile/builder/testing/type_traits.hpp" +#include "ck/utility/type_convert.hpp" +#include +#include +#include +#include +#include +#include +#include + +/// This file contains a few debugging utilities, mainly focused around +/// tensor data. The idea is that the functionality in this file is not +/// necessarily used in any testing directly, but is available for the +/// programmer to help with debugging problems. These utilities themselves +/// should be tested just the same, though, so that they don't undergo +/// bitrot while they are not actively being used. + +namespace ck_tile::builder::test { + +namespace detail { + +/// @brief Custom number punctuation for CK-Builder debugging. +/// +/// During debugging, the locale is usually left to the default C locale. +/// The C locale does not have any thousands separator, which makes +/// large numbers hard to read. This is a specialization of the default +/// C++ number punctuation (`std::numpunct`) which separates thousands +/// using `'`, which helps getting a quick overview of the magnitude of +/// a number. This character is chosen because C++14 allows number literals +/// to have this character. +/// +/// @note When using this locale, be sure to restore the old locale in the +/// event that the user actually wants to use a non-standard locale. +/// +/// @see std::numpunct +struct numpunct : std::numpunct +{ + char do_thousands_sep() const override { return '\''; } + + std::string do_grouping() const override + { + // See std::numpunct, this separates by thousands. + return "\3"; + } +}; + +} // namespace detail + +/// @brief Print information about a tensor descriptor. +/// +/// This function dumps useful information from a tensor descriptor to a +/// stream, `std::cout` by default. This includes the number of elements +/// in the tensor, the size of the backing space, lengths, strides, etc. +/// +/// @note All information is printed using a lightly modified locale to +/// get a unified printing experience. The original locale in `stream` is +/// temporarily replaced, but restored before the function returns. +/// +/// @tparam DT The tensor element datatype +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param name A name for the tensor descriptor. +/// @param desc The tensor descriptor to print. +/// @param out The stream to print to, `std::cout` by default. +template +void print_descriptor(std::string_view name, + const TensorDescriptor& desc, + std::ostream& out = std::cout) +{ + // Create a custom stream with a completely new config (locale, + /// precision, fill, etc). Use an osyncstream to buffer the output + /// while were at it (its not likely to help a lot, but why not). + std::osyncstream stream(out.rdbuf()); + stream.imbue(std::locale(std::locale(), new detail::numpunct{})); + + // Print name along with some generic info + const auto size = desc.get_element_size(); + const auto space = desc.get_element_space_size(); + const auto bytes = desc.get_element_space_size_in_bytes(); + const auto packed = desc.is_packed(); + + stream << "Descriptor \"" << name << "\":\n" + << " data type: " << DT << '\n' + << " size: " << size << " elements\n" + << " space: " << space << " elements (" << bytes << " bytes)\n" + << " lengths: " << desc.get_lengths() << '\n' + << " strides: " << desc.get_strides() << '\n' + << " packed: " << (packed ? "yes" : "no") << std::endl; +} + +/// @brief User configuration for printing tensors. +/// +/// This structure houses some configuration fields for customizing how tensors +/// are printed. The default is usually good, though `TensorPrintConfig::unlimited()` +/// is useful if you want to print the entire tensor to the output regardless of size. +struct TensorPrintConfig +{ + /// @brief A limit for the number of columns in a tensor row to print. + /// + /// Each row of a tensor will be printed as a sequence of values. At most + /// this number of values are printed, if there are more, `row_skip_val` + /// will be printed in between. + size_t col_limit = 10; + + /// @brief A limit for the number of rows in a 2D matrix to print + /// + /// Tensors with rank higher than 1 are printed as a single matrix or a series + /// of matrix slices. At most this number of rows of the matrix will be printed. + /// If there are more rows, a row of `matrix_row_skip_val` and possibly + /// `row_skip_val` will be printed in between. + size_t row_limit = 10; + + /// @brief A limit for the number of 2D tensor slices to print. + /// + /// Tensors with rank higher than 2 are flattened into a sequence of slices. At + /// most this number of slices will be printed. + size_t slice_limit = 8; + + /// @brief Text to print at the start of a row of values. + /// + /// This is used by `TensorPrinter`, and printed at the start of a row of tensor + /// values. + std::string_view row_prefix = " "; + + /// @brief Text to print between fields of a row. + /// + /// This is used by `TensorPrinter`, and printed between each value of a row of + /// tensor values. + std::string_view row_field_sep = " "; + + /// @brief Text to print when skipping some number of row values. + /// + /// This is used by `TensorPrinter`, and printed instead of some number of values + /// when the number of values in a row is too large to all print. + std::string_view row_skip_val = "..."; + + /// @brief Text to print when skipping a row of a matrix. + /// + /// This is used by `TensorPrinter`, and printed instead of a value when some + /// number of rows is skipped when printing a matrix. This is similar to + /// `row_skip_val`, except in the vertical direction. Note that ALL values + /// in the skip row is printed this way. + std::string_view matrix_row_skip_val = "..."; + + /// @brief The precision of tensor floating point values. + /// + /// Set the number of decimal digits that is printed for a floating point value. + int float_precision = 3; + + /// @brief Return the default print config, but without any printing limits. + /// + /// This is useful if you want to print the *entire* tensor, but be aware that + /// this may print a lot of data if the tensor is large! + constexpr static TensorPrintConfig unlimited() + { + return { + .col_limit = std::numeric_limits::max(), + .row_limit = std::numeric_limits::max(), + .slice_limit = std::numeric_limits::max(), + }; + } +}; + +namespace detail { + +/// @brief Iterate over a range of values, but limit the amount of iterations. +/// +/// Iterate over values `0..n`, but if `limit > n`, only iterate over the +/// first and last few (`limit // 2)` items. This can be used to iterate over +/// large ranges in a way that not too many values are visited. Its primarily +/// used when printing tensors so that not all values of a giant tensor are +/// dumped to the user's terminal. +/// +/// @param n The total number of items to iterate over. +/// @param limit The maximum number of items to iterate over. Use even values +/// for best results, as this will lead to the same amount of values in the +/// "begin" and "end" sections. +/// @param f A functor to invoke for each element. The sole parameter is the +/// index. +/// @param delim A functor to invoke between the begin and end sections. This +/// function is only invoked if any items are skipped at all. +void limited_foreach(size_t n, size_t limit, auto f, auto delim) +{ + if(n <= limit) + { + for(size_t i = 0; i < n; ++i) + f(i); + } + else + { + const auto begin_count = (limit + 1) / 2; // Round up in case `delim` is odd. + const auto end_count = limit / 2; + const auto skip_count = n - limit; + + for(size_t i = 0; i < begin_count; ++i) + f(i); + + delim(skip_count); + + for(size_t i = n - end_count; i < n; ++i) + f(i); + } +}; + +/// @brief Output stream requirements for use with `TensorPrinter`. +/// +/// The `TensorPrinter` does not write to an ostream directly, but rather writes to +/// a custom stream object. This is mainly so that the user of `TensorPrinter` can +/// get more details than directly with an ostream. Basically, a valid implementation +/// of `TensorPrintStream` exposes 3 things: +/// - A way to print (stringified) tensor elements. +/// - A way to print arbitrary text messages. These are mostly for formatting. This +/// should be implemented using varargs which are directly folded into an ostream, +/// so that functions can be used. +/// - A way to query the max width of any `val` field. +/// +/// @see TensorPrinter for more information. +template +concept TensorPrintStream = requires(Stream& stream, std::string_view val) { + { stream.max_width } -> std::convertible_to; + { stream.val(val) } -> std::same_as; + { stream.msg() } -> std::same_as; + { stream.msg("msg") } -> std::same_as; + { stream.msg(std::setw(3), std::setfill(4), "msg", val) } -> std::same_as; +}; + +/// @brief Utility to print tensors. +/// +/// This structure implements the main logic for printing tensors to a stream. +/// In order to help with formatting, the `TensorPrinter` abstracts over a custom +/// stream type, see `TensorPrintStream`. This type is actually mostly an internal +/// helper and mainly used by `print_tensor`. Its supposed to be constructed +/// manually, but see the field docs for what is required. +/// +/// @tparam DT The data type of the tensor to print. +/// @tparam RANK The rank (number of spatial dimensions) of the tensor to print. +/// +/// @see print_tensor +template +struct TensorPrinter +{ + /// The name of this tensor. This will be used during printing to add extra + /// clarity about what the user is seeing. + std::string_view name; + + /// Configuration details of how to print the tensor. This should be able to + /// be specified by the user, but the default is good in most cases. + TensorPrintConfig config; + + /// The lengths of the tensor to print. These values are directly from + /// `TensorDescriptor::get_lengths()`, stored here to avoid querying them + /// repeatedly. + Extent lengths; + + /// The strides of the tensor to print. These values are directly from + /// `TensorDescriptor::get_strides()`, stored here to avoid querying them + /// repeatedly. + Extent strides; + + /// The tensor's backing buffer. This memory should be host-accessible, for + /// example by copying it back to the host first. + const void* h_buffer; + + /// A common stringstream for stringifying tensor values. This is here mostly + /// so that we can cache the internal allocation. + std::stringstream ss; + + /// @brief Low-level tensor value stringifying function. + /// + /// Print value `value` to the stringstream `ss` (member value). This function + /// is the actual low-level printing function that prints each element of the + /// tensor. In order to get a robust printing implementation, the value is written + /// directly into a stringstream, which is then further processed to be actually + /// written to the output. This way, the format doesn't depend on the ostream + /// configuration. + /// + /// @param value The value to print to the stream. + void stringify_value(const void* value) + { + if constexpr(DT == DataType::UNDEFINED_DATA_TYPE) + { + ss << "??"; + return; + } + + using CKType = detail::cpp_type_t
; + const auto ck_value = *static_cast(value); + + if constexpr(DT == DataType::I32 || DT == DataType::I8 || DT == DataType::U8) + ss << ck_value; + else if constexpr(DT == DataType::FP64 || DT == DataType::FP32) + ss << std::fixed << std::setprecision(config.float_precision) << ck_value; + else if constexpr(DT == DataType::FP16 || DT == DataType::BF16 || DT == DataType::FP8 || + DT == DataType::BF8) + ss << std::fixed + << std::setprecision(config.float_precision) + // Note: We are using CK types here (cpp_type_t uses DataTypeToCK), so + // use CK's type_convert function. + << ::ck::type_convert(ck_value); + else + // TODO: Tuple types? Currently not implemented in DataTypeToCK... + static_assert(false, "stringify_value unsupported data type, please implement"); + } + + /// @brief Print the value at an index to a stream. + /// + /// This function reads the value at `index` and prints it to `stream` (using + /// `stream.val(...)`). + /// + /// @param stream The stream to print to. + /// @param index The index in the tensor of the value to print. + void print_value(TensorPrintStream auto& stream, const Extent& index) + { + const auto offset = calculate_offset(index, strides); + const auto* value_ptr = + &static_cast(h_buffer)[offset * data_type_sizeof(DT)]; + + // Reset the stream without allocating. + // ss.str("") allocates... + ss.clear(); + ss.seekg(0); + ss.seekp(0); + stringify_value(value_ptr); + // ss.view() returns a view of the ENTIRE buffer, which may have + // lingering data since we used seekp() and seekg() to reset the + // stream. For some reason std::stringstream works this way... + // Fortunately tellp() returns how many bytes we've actually + // written. + const auto view = ss.view().substr(0, ss.tellp()); + stream.val(view); + } + + /// @brief Print a 1D row to a stream. + /// + /// Print a row of tensor values to the stream. This function is used for both + /// 1D tensors and for rows of 2D tensors, in which the base coordinate is given + /// by `index`. Note that the print configuration is taken into account to avoid + /// flooding the user's terminal with values. + /// + /// @param stream The stream to print to. + /// @param index The index of the row to print. The rightmost index element is + /// ignored, as that is the index of the value _within_ the row. + void print_row(TensorPrintStream auto& stream, Extent& index) + { + // See note in `print_matrix`. + stream.msg(config.row_prefix); + limited_foreach( + lengths[RANK - 1], + config.col_limit, + [&](auto i) { + stream.msg(config.row_field_sep); + index[RANK - 1] = i; + print_value(stream, index); + }, + [&]([[maybe_unused]] auto skip_count) { + stream.msg(config.row_field_sep); + // Note: Not using stream.val(...) here because we don't want this + // field to partake in max_width computation, nor do we want to + // pad it to the max width. + stream.msg(config.row_skip_val); + }); + + stream.msg('\n'); + } + + /// @brief Print a 2D matrix to a stream. + /// + /// Print a matrix of tensor values to the stream. This function is used for both + /// 2D and slices of higher-dimensional tensors, in which the base coordinate is + /// given by `index`. Note that the print configuration is taken into account to + /// avoid flooding the user's terminal with values. + /// + /// @param stream The stream to print to. + /// @param index The index of the row to print. The 2 rightmost index elements are + /// ignored, as those are the indices of values _within_ the matrix. + void print_matrix(TensorPrintStream auto& stream, Extent& index) + { + limited_foreach( + lengths[RANK - 2], + config.row_limit, + [&](auto i) { + index[RANK - 2] = i; + print_row(stream, index); + }, + [&]([[maybe_unused]] auto row_skip_count) { + // When we encounter a skip row, continue with the same logic + // as printing 1D tensor rows. Instead of actual values, we will + // simply print MATRIX_ROW_SKIP_VAL (usually something like "..."). + stream.msg(config.row_prefix); + limited_foreach( + lengths[RANK - 1], + config.col_limit, + [&]([[maybe_unused]] auto i) { + stream.msg(config.row_field_sep); + // Note: We're using `stream.val(...)` here because we *do* want this field + // to partake in max_width computation, and we *do* want to pad it like + // value fields. This is so that these appear the same width as actual + // values, so that everything is neatly aligned. This also ensures that if + // there are no skip values, then the size of the skip field is not taken + // into account. + stream.val(config.matrix_row_skip_val); + }, + [&]([[maybe_unused]] auto col_skip_count) { + stream.msg(config.row_field_sep); + // Note: Not using stream.val(...) here because we don't want this + // field to partake in max_width computation, nor do we want to + // pad it to the max width. + stream.msg(config.row_skip_val); + }); + stream.msg('\n'); + }); + } + + /// @brief Print a tensor to a stream. + /// + /// This is the main tensor printing function. It calls `print_row` or `print_matrix` + /// (possibly repeatedly) as required. This function prints the entire tensor in + /// `h_buffer` regardless. + /// + /// @param stream The stream to print to. + void print_tensor(TensorPrintStream auto& stream) + { + Extent zero_coord = {}; + if constexpr(RANK == 0) + { + // 0D case: just print the one value + stream.msg(config.row_prefix); + stream.msg(config.row_field_sep); + print_value(stream, zero_coord); + stream.msg('\n'); + } + else if constexpr(RANK == 1) + { + // 1D case: dump everything on one line + print_row(stream, zero_coord); + } + else if constexpr(RANK == 2) + { + // 2D case: print a 2D matrix + print_matrix(stream, zero_coord); + } + else + { + // For higher dimensions, print each window as a slice + // We want to limit the *total* number of slices using `slice_limit`, + // not the number in each axis. So flatten the remaining dimensions. + // This also avoids recursion in this function in general. + + // First get the shape minus the 2 inner dimensions + Extent outer_shape; + std::copy_n(lengths.begin(), RANK - 2, outer_shape.begin()); + + NdIter iter(outer_shape); + detail::limited_foreach( + iter.numel(), + config.slice_limit, + [&](auto outer_flat_index) { + // Now decode the outer index and turn it back into a complete index + const auto outer_index = iter(outer_flat_index); + Extent index = {}; + std::copy_n(outer_index.begin(), RANK - 2, index.begin()); + + // Print an extra separating line between two slices + if(outer_flat_index != 0) + stream.msg('\n'); + + // Print an information header about the current slice + stream.msg("Tensor \"", name, "\", slice ["); + for(auto x : outer_index) + stream.msg(x, ", "); + stream.msg(":, :]\n"); + + // And print is as matrix + print_matrix(stream, index); + }, + [&](auto skip_count) { stream.msg("\n(skipping ", skip_count, " slices...)\n"); }); + } + } +}; + +/// @brief Implementation of `TensorPrintStream` to figure out the maximum +/// width of a field. +/// +/// In order to produce neatly aligned tensors, where all values of each row +/// appear on the same columns, we have to figure out the maximum width of +/// each field. This print stream helps with that: It does not actually print +/// anything, it just figures out the maximum width of any value (not message). +/// +/// @details OK, this function does actually print things, but only to an +/// internal `stringstream`. This is so that we can easily figure out the +/// width of the field (in bytes), just by counting the amount of bytes +/// written into the string stream. +/// +/// @see TensorPrintStream +struct MaxFieldWidthStream +{ + size_t max_width = 0; + + /// @brief Print a tensor value to the stream + /// + /// "Print" a value to the stream. This function figures out the width + /// of the value when printed, and then composes it with `max_width` to + /// figure out the total maximum. + /// + /// @param value The value to print. + void val(std::string_view value) { max_width = std::max(max_width, value.size()); } + + /// @brief Print a message to the stream. + /// + /// "Print" a non-value message to the stream. In this implementation, + /// everything is discarded. + /// + /// @tparam Args the types of the values to print. + /// + /// @param args The values to print. + template + void msg([[maybe_unused]] const Args&... args) + { + } +}; + +/// @brief Implementation of `TensorPrintStream` which actually prints. +/// +/// In contrast to `MaxFieldWidthStream`, this function actually prints +/// to an ostream, taking the value produced by that type into account. +struct OutputStream +{ + std::ostream& stream; + // The maximum width of each tensor value. + size_t max_width; + + /// @brief Print a tensor value to the stream + /// + /// Actually print a value into the stream, (right-)padding it to + /// `max_width`. + /// + /// @param value The value to print. + void val(std::string_view value) + { + stream << std::setfill(' ') << std::setw(max_width) << value; + } + + /// @brief Print a message to the stream. + /// + /// This prints a non-value message directly to the ostream, as if + /// folded via `operator<<`. + /// + /// @tparam Args the types of the values to print. + /// + /// @param args The values to print. + template + void msg(const Args&... args) + { + (stream << ... << args); + } +}; + +} // namespace detail + +/// @brief Print device tensor values to an ostream. +/// +/// Print the values of a tensor to an ostream. This function neatly formats +/// the tensor according to `config`, tabulating the values so that they are +/// vertically aligned and skipping values to prevent flooding the terminal. +/// With the default config, this function is good to get a quick overview +/// of what a tensor looks like. For a more complete overview, consider +/// supplying `TensorPrintConfig::unlimited()` to get everything (but beware +/// of flooding the terminal). Tensors are printed with the rightmost-dimension +/// as inner dimension, these values appear on the same row in the output. +/// +/// @tparam DT The data type of the tensor. +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param name A name for the tensor. This will be used to add some extra identifying +/// information during printing. +/// @param desc The descriptor for the tensor memory layout. +/// @param d_buffer The tensor's actual data buffer. This is expected to be +/// _device accessible_ memory, as its copied back to the host first. +/// @param config Tensor printing configuration. This allows tweaking some details +/// of the printing process. +/// @param out The ostream to print to, `std::cout` by default. +template +void print_tensor(std::string_view name, + const TensorDescriptor& desc, + const void* d_buffer, + TensorPrintConfig config = {}, + std::ostream& out = std::cout) +{ + // Copy memory to the host (printing from device is sketchy) + const auto space = desc.get_element_space_size_in_bytes(); + std::vector h_buffer(space); + check_hip(hipMemcpy(h_buffer.data(), d_buffer, space, hipMemcpyDeviceToHost)); + + // Create a custom stream with a completely new config (locale, + /// precision, fill, etc). Use an osyncstream to buffer the output + /// while were at it (its not likely to help a lot, but why not). + std::osyncstream stream(out.rdbuf()); + stream.imbue(std::locale(std::locale(), new detail::numpunct{})); + + // Print a header for the entire tensor (regardless of if there are multiple slices). + stream << "Tensor \"" << name << "\": shape = " << desc.get_lengths() << "\n"; + + detail::TensorPrinter printer = { + .name = name, + .config = config, + .lengths = desc.get_lengths(), + .strides = desc.get_strides(), + .h_buffer = h_buffer.data(), + .ss = std::stringstream(), + }; + + // We're actually going to print twice: once to figure out the + // maximum width of the fields, and once to actually print to the stream. + + // Print once to figure out the maximum field width. + detail::MaxFieldWidthStream max_field_width; + printer.print_tensor(max_field_width); + + // Actually print to the output stream. + detail::OutputStream tensor_out = { + .stream = stream, + .max_width = max_field_width.max_width, + }; + printer.print_tensor(tensor_out); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp index 6043ba2103..3f5a9dd465 100644 --- a/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp @@ -81,4 +81,15 @@ inline DeviceBuffer alloc_buffer(size_t size) return DeviceBuffer(d_buf); } +/// @brief "Align" an offset to a multiple of a particular alignment. +/// +/// Returns `addr` aligned to the next multiple of `alignment`. +/// +/// @param addr The address to align. +/// @param alignment The alignment. +inline size_t align_fwd(size_t addr, size_t alignment) +{ + return addr % alignment == 0 ? addr : addr - addr % alignment + alignment; +} + } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp index 15fe4d89db..4c99f05c46 100644 --- a/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -123,6 +124,33 @@ struct Extent : std::array template Extent(T...) -> Extent; +/// @brief Extent printer +/// +/// This function implements an ostream printing overload for `Extent`, so that +/// they can be printed in the usual `stream << extent` fashion. +/// +/// @tparam RANK Rank (number of spatial dimensions) of the extent. +/// +/// @param stream The stream to print the extent to. +/// @param extent The extent to print to the stream. +template +std::ostream& operator<<(std::ostream& stream, const Extent& extent) +{ + stream << '['; + bool first = true; + for(const auto x : extent) + { + if(first) + first = false; + else + stream << ", "; + + stream << x; + } + + return stream << ']'; +} + /// @brief Concept for automatically deriving tensor memory layout. /// /// A `TensorStridesGenerator` is a type which can be used to automatically diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_foreach.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_foreach.hpp index f078a1ac82..28ab954de9 100644 --- a/experimental/builder/include/ck_tile/builder/testing/tensor_foreach.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_foreach.hpp @@ -18,6 +18,102 @@ namespace ck_tile::builder::test { +/// @brief Utility structure for N-dimensional iteration using a flat index +/// +/// This structure's main purpose is to "unmerge" a flattened index into a +/// multi-dimensional index, which helps when iterating over multi-dimensional +/// indices without having to write an arbitrary amount of nested for loops. +/// A minimal amount of precomputation must be done to do this efficiently, +/// which is handled in the constructor of this type. +/// +/// @details Decoding a flat index into a multi-dimensional index is done by +/// first computing a reverse scan of the shape. These values can then be +/// used to decode the index in the usual way: +/// +/// x = flat_idx / (size_y * size_z) +/// y = flat_idx % (size_y * size_z) / size_z +/// z = flat_idx % (size_y * size_z) % size_z +/// etc +/// +/// The decode order is such that the innermost dimension (right in +/// the shape extent) changes the fastest. +/// +/// @tparam RANK The rank (number of spatial dimensions) of the tensor to +/// iterate. +template +struct NdIter +{ + /// @brief Prepare N-dimensional iteration over a particular shape. + /// + /// Precompute ashape into a form that can be used to easily decode a flat + /// index into a multi-dimensional index. + /// + /// @param shape The shape to iterate over. + explicit NdIter(const Extent& shape) + { + // Precompute shape_scan = [..., shape[-2] * shape[-1], shape[-1], 1] + + numel_ = 1; + for(int i = RANK; i > 0; --i) + { + shape_scan_[i - 1] = numel_; + numel_ *= shape[i - 1]; + } + } + + /// @brief Unflatten a flat index into a multi-dimensional index + /// + /// This applies the usual multi-dimensional indexing method over the + /// precomputed shape scan to get back a multi-dimensional index. + /// The decode order is such that the innermost dimension (right in + /// the shape extent) changes the fastest. + /// + /// @param flat_index The "flattened" (1-dimensional) index of the tensor + /// + /// @returns A multi-dimensional index into the tensor + /// + /// @pre `0 <= flat_index < size()` (in other words, the `flat_index` must + /// be in bounds of the tensor shape that this `NdIter` was made from). + __host__ __device__ Extent operator()(size_t flat_index) const + { + Extent index = {}; + auto idx = flat_index; + for(size_t i = 0; i < RANK; ++i) + { + const auto scanned_dim = shape_scan_[i]; + index[i] = idx / scanned_dim; + idx %= scanned_dim; + } + + return index; + } + + /// @brief Return the total elements to iterate over + /// + /// Get the total number of elements in the shape to iterate over. This value + /// can be used to construct a complete for loop to iterate over all indices + /// of a tensor, for example: + /// + /// for(size_t i = 0; i < iter.numel(); ++i) + /// { + /// const auto index = iter(i); + /// use(index); + /// } + __host__ __device__ size_t numel() const { return numel_; } + + private: + /// Reverse (right) scan of the shape to iterate over. + Extent shape_scan_; + + /// The total number of elements in the shape. This value turns out to be almost + /// always required when iterating over a shape, so just store it in this type + /// so that it is easily accessible. + size_t numel_; +}; + +template +NdIter(Extent) -> NdIter; + /// @brief Concept for constraining tensor iteration functors. /// /// This concept checks that a functor has the correct signature for @@ -50,28 +146,19 @@ constexpr int DEVICE_FOREACH_BLOCK_SIZE = 256; /// @tparam F The type of the callback to invoke. This function must be /// compatible with execution as a __device__ function. /// -/// @param numel The total number of elements in the tensor. -/// @param shape_scan A right-exclusive scan of the shape of the tensor. +/// @param iter An NdIter instance to help iterating over the tensor. /// @param f The callback to invoke for each index of the tensor. This /// functor must be eligible for running on the GPU. template requires ForeachFunctor __global__ __launch_bounds__(BLOCK_SIZE) // - void foreach_kernel(const size_t numel, Extent shape_scan, F f) + void foreach_kernel(NdIter iter, F f) { const auto gid = blockIdx.x * BLOCK_SIZE + threadIdx.x; - for(size_t flat_idx = gid; flat_idx < numel; flat_idx += gridDim.x * BLOCK_SIZE) + for(size_t flat_idx = gid; flat_idx < iter.numel(); flat_idx += gridDim.x * BLOCK_SIZE) { // Compute the current index. - Extent index = {}; - - size_t idx = flat_idx; - for(size_t i = 0; i < RANK; ++i) - { - const auto scanned_dim = shape_scan[i]; - index[i] = idx / scanned_dim; - idx %= scanned_dim; - } + const auto index = iter(flat_idx); // Then invoke the callback with the index. f(index); @@ -160,18 +247,12 @@ void tensor_foreach(const Extent& shape, ForeachFunctor auto f) // order in the kernel is from large-to-small. Right layout is the // easiest solution for that. - Extent shape_scan; - size_t numel = 1; - for(int i = RANK; i > 0; --i) - { - shape_scan[i - 1] = numel; - numel *= shape[i - 1]; - } + NdIter iter(shape); // Reset any errors from previous launches. (void)hipGetLastError(); - kernel<<>>(numel, shape_scan, f); + kernel<<>>(iter, f); check_hip(hipGetLastError()); } @@ -179,7 +260,7 @@ void tensor_foreach(const Extent& shape, ForeachFunctor auto f) /// /// This concept checks that a functor has the correct signature for /// use with the `fill_tensor` function. -template +template concept FillTensorFunctor = requires(const F& f, const Extent& index) { { f(index) } -> std::convertible_to>; }; @@ -199,7 +280,7 @@ concept FillTensorFunctor = requires(const F& f, const Extent& index) { /// @param f A functor used to get the value at a particular coordinate. /// /// @see FillTensorFunctor -template +template void fill_tensor(const TensorDescriptor& desc, void* buffer, FillTensorFunctor auto f) @@ -218,7 +299,7 @@ void fill_tensor(const TensorDescriptor& desc, /// /// This concept checks that a functor has the correct signature for /// use with the `fill_tensor_buffer` function. -template +template concept FillTensorBufferFunctor = requires(const F& f, size_t index) { { f(index) } -> std::convertible_to>; }; @@ -239,7 +320,7 @@ concept FillTensorBufferFunctor = requires(const F& f, size_t index) { /// @param f A functor used to get the value at a particular index. /// /// @see FillTensorBufferFunctor -template +template void fill_tensor_buffer(const TensorDescriptor& desc, void* buffer, FillTensorBufferFunctor
auto f) @@ -247,7 +328,19 @@ void fill_tensor_buffer(const TensorDescriptor& desc, fill_tensor(desc.get_space_descriptor(), buffer, [f](auto index) { return f(index[0]); }); } -template +/// @brief Utility for clearing tensor buffers to a particular value. +/// +/// This function initializes all memory backing a particular tensor buffer to +/// one specific value, zero by default. Note that this function ignores strides, +/// and clears the entire buffer backing the tensor. +/// +/// @tparam DT The tensor element datatype +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param desc The descriptor of the tensor to initialize. +/// @param buffer The memory of the tensor to initialize. +/// @param value The value to initialize the tensor buffer with. +template void clear_tensor_buffer(const TensorDescriptor& desc, void* buffer, detail::cpp_type_t
value = detail::cpp_type_t
{0}) diff --git a/experimental/builder/include/ck_tile/builder/testing/testing.hpp b/experimental/builder/include/ck_tile/builder/testing/testing.hpp index 609c93cacf..eb16402bc2 100644 --- a/experimental/builder/include/ck_tile/builder/testing/testing.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/testing.hpp @@ -5,6 +5,8 @@ #include +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" #include "ck_tile/builder/testing/validation.hpp" /// This file is the main header for the CK-Builder testing system. A high-level @@ -132,8 +134,8 @@ struct Outputs; /// be created using `alloc_inputs()` and that an instance of the corresponding /// `Inputs` structure can be obtained using `.get()`. /// -/// @note The easiest way to implement this type is to use the `DeviceBuffer` -/// type to allocate individual device buffers for each input tensor. +/// @note A default implementation is provided for this type if `Inputs` +/// supports `TensorReflectable`. /// /// @tparam SIGNATURE The signature to specialize the structure for. /// @@ -151,8 +153,8 @@ struct UniqueInputs; /// be created using `alloc_outputs()` and that an instance of the corresponding /// `Outputs` structure can be obtained using `.get()`. /// -/// @note The easiest way to implement this type is to use the `DeviceBuffer` -/// type to allocate individual device buffers for each output tensor. +/// @note A default implementation is provided for this type if `Outputs` +/// supports `TensorReflectable`. /// /// @tparam SIGNATURE The signature to specialize the structure for. /// @@ -197,6 +199,12 @@ concept ValidUniqueOutputs = requires(UniqueOutputs& inputs) { /// amount of memory required and then allocate it on the device, for example /// using `alloc_buffer` or `alloc_tensor_buffer`. /// +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. +/// +/// @note A default implementation is provided for this function if `Inputs` +/// supports `TensorReflectable`. +/// /// @tparam SIGNATURE The signature to specialize the structure for. /// /// @param args The run-time arguments of the operation. @@ -207,22 +215,22 @@ concept ValidUniqueOutputs = requires(UniqueOutputs& inputs) { /// @see alloc_tensor_buffer() template requires ValidUniqueInputs -UniqueInputs alloc_inputs(const Args& args); +UniqueInputs alloc_inputs(const Args& args) = delete; -/// @brief Allocate inputs corresponding to a signature. +/// @brief Initialize inputs corresponding to a signature. /// /// The `init_inputs()` function is used to initialize pseudo-random data /// to the tensors specified in the Inputs structure. Implementors should /// fill each of the tensors in `inputs` with appropriate random data. /// +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. +/// /// @tparam SIGNATURE the signature to specialize the structure for. /// /// @param args The run-time arguments of the operation. /// @param inputs The operation inputs to initialize with random data. /// -/// @note This function is explicitly deleted to generate compile errors -/// for missing implementations. -/// /// @see Inputs /// @see tensor_initialization template @@ -235,13 +243,16 @@ void init_inputs(const Args& args, Inputs inputs) = delete /// amount of memory required and then allocate it on the device, for example /// using `alloc_buffer` or `alloc_tensor_buffer`. /// +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. +/// +/// @note A default implementation is provided for this function if `Outputs` +/// supports `TensorReflectable`. +/// /// @tparam SIGNATURE The signature to specialize the structure for. /// /// @param args The run-time arguments of the operation. /// -/// @note This function is explicitly deleted to generate compile errors -/// for missing implementations. -/// /// @see Outputs /// @see UniqueOutputs /// @see alloc_buffer() @@ -262,15 +273,15 @@ UniqueInputs alloc_outputs(const Args& args) = delete; /// were incorrect, and where (a subset of) those elements are located within /// the tensor. See `ValidationReport` for more information about the report. /// +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. +/// /// @tparam SIGNATURE The signature to specialize the structure for. /// /// @param args The run-time arguments of the operation. /// @param actual The actual results, the results of the operation to-be-tested. /// @param expected The expected results, the results of the reference implementation. /// -/// @note This function is explicitly deleted to generate compile errors -/// for missing implementations. -/// /// @see ValidationReport template ValidationReport validate(const Args& args, diff --git a/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp b/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp new file mode 100644 index 0000000000..81d5b7a6f5 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp @@ -0,0 +1,199 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +/// testing.hpp requires developers of a type of SIGNATURE to implement +/// quite a lot of functionality for each SIGNATURE. For example, next +/// to `Args`, `Inputs`, `Outputs`, `run`, they also have to define +/// `UniqueInputs`, `UniqueOutputs`, `alloc_inputs`, `alloc_outputs`, +/// and `validate`. The implementation of these latter few functions +/// is usually quite straight forward and adds a bunch of copy-paste +/// overhead. The functionality in this file offers an alternative +/// route: By implementing some reflection functionality in `Inputs` +/// and `Outputs`, we can automatically derive most of the functionality. + +namespace ck_tile::builder::test { + +/// @brief Check whether an `Input` or `Output` struct can be reflected. +/// +/// In order to avoid having to manually redefine a bunch of types related to +/// each `Inputs`/`Outputs` structure, those structures can also provide some +/// "reflection" functionality. To this end, they should implement +/// `static void reflect(const Args args&, auto inspect)`, where `inspect` +/// is called with information about each field in the struct. In more detail, +/// the signature of the `inspect` function is as follows: +/// +/// void inspect( +/// // A human-readable name for the tensor +/// std::string_view name, +/// // Descriptor for the tensor in memory, usually obtained via `args`. +/// const TensorDescriptor& desc, +/// // Member pointer to a field of `T`, which is a GPU-memory pointer +/// // to the relevant tensor memory. +/// void* T::* ptr); +/// +/// Here, `T` is `Inputs` or `Outputs`. +/// +/// @see Inputs +/// @see Outputs +template +concept TensorReflectable = requires(const Args& args) { + { + T::reflect(args, + []([[maybe_unused]] std::string_view name, + // Note: This will be a TensorDescriptor, but the actual + // DT and RANK may differ depending on member. + [[maybe_unused]] const auto& desc, + [[maybe_unused]] void* T::*ptr) {}) + }; +}; + +namespace detail { + +/// The default alignment between tensors allocated separately +/// by `UniqueTensors`. This should be large enough to accomodate +/// any type. hipMalloc returns an alignment of 256 by default. +constexpr size_t TENSOR_ALIGNMENT = 256; + +/// @brief Common type for automatically managing memory of sets of tensors. +/// +/// This type implements the automatic memory management logic for `Inputs` and +/// `Outputs` that support reflection. +/// +/// @tparam SIGNATURE The signature to specialize the structure for. +/// @tparam Tensors The `Inputs` or `Outputs` type corresponding to `SIGNATURE`. +template + requires TensorReflectable +struct UniqueTensors +{ + /// @brief Allocate tensors. + /// + /// This function computes the total size of memory to allocate according to + /// the tensors in `args`, and then allocates it as a continuous buffer. + /// + /// @param args The run-time arguments of the operation. + explicit UniqueTensors(const Args& args) + { + // First compute the total size of all tensors combined + size_t total_size = 0; + Tensors::reflect(args, + [&, this]([[maybe_unused]] std::string_view name, + const auto& desc, + [[maybe_unused]] void* Tensors::*ptr) { + total_size = align_fwd(total_size, TENSOR_ALIGNMENT); + total_size += desc.get_element_space_size_in_bytes(); + }); + + data_ = alloc_buffer(total_size); + + // Now assign the pointers based on the same offsets that + // we computed in the first loop. + size_t offset = 0; + Tensors::reflect(args, + [&, this]([[maybe_unused]] std::string_view name, + const auto& desc, + [[maybe_unused]] void* Tensors::*ptr) { + offset = align_fwd(offset, TENSOR_ALIGNMENT); + tensors_.*ptr = data_.get() + offset; + offset += desc.get_element_space_size_in_bytes(); + }); + } + + /// @brief Return raw `Inputs` or `Outputs` type. + /// + /// @see ValidUniqueInputs + /// @see ValidUniqueOutputs + Tensors get() const { return tensors_; } + + private: + /// Owning pointer of input memory + DeviceBuffer data_; + /// Struct with pointers to each tensor. Stored here so that we + /// don't need to keep recomputing it. + Tensors tensors_; +}; + +} // namespace detail + +/// @brief Implementation of `UniqueInputs` for `Inputs` that support reflection. +/// +/// @tparam SIGNATURE The signature to specialize for. +/// +/// @see UniqueInputs +template + requires TensorReflectable, SIGNATURE> +struct UniqueInputs : detail::UniqueTensors> +{ + using detail::UniqueTensors>::UniqueTensors; +}; + +/// @brief Implementation of `UniqueOutputs` for `Outputs` that support reflection. +/// +/// @tparam SIGNATURE The signature to specialize for. +/// +/// @see UniqueOutputs +template + requires TensorReflectable, SIGNATURE> +struct UniqueOutputs : detail::UniqueTensors> +{ + using detail::UniqueTensors>::UniqueTensors; +}; + +/// @brief Implementation of `alloc_inputs` for `Inputs` that support reflection. +/// +/// @tparam SIGNATURE The signature to specialize for. +/// +/// @param args The run-time arguments of the operation. +/// +/// @see alloc_inputs +template + requires TensorReflectable, SIGNATURE> +UniqueInputs alloc_inputs(const Args& args) +{ + static_assert(ValidUniqueInputs, "sanity check"); + return UniqueInputs(args); +} + +/// @brief Implementation of `alloc_outputs` for `Outputs` that support reflection. +/// +/// @tparam SIGNATURE The signature to specialize for. +/// +/// @param args The run-time arguments of the operation. +/// +/// @see alloc_outputs +template + requires TensorReflectable, SIGNATURE> +UniqueOutputs alloc_outputs(const Args& args) +{ + static_assert(ValidUniqueOutputs, "sanity check"); + return UniqueOutputs(args); +} + +/// @brief Implementation of `validate` for `Outputs` that support reflection. +/// +/// @tparam SIGNATURE The signature to specialize for. +/// +/// @param args The run-time arguments of the operation. +/// @param actual The actual results, the results of the operation to-be-tested. +/// @param expected The expected results, the results of the reference implementation. +/// +/// @see alloc_outputs +template + requires TensorReflectable, SIGNATURE> +ValidationReport +validate(const Args& args, Outputs actual, Outputs expected) +{ + ValidationReport report; + + Outputs::reflect( + args, [&](std::string_view name, const auto& desc, void* Outputs::*ptr) { + report.check(name, desc, actual.*ptr, expected.*ptr); + }); + + return report; +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp b/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp index 8db0e5d25d..4026642bd0 100644 --- a/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp @@ -39,7 +39,7 @@ constexpr size_t data_type_sizeof(DataType data_type) case DataType::FP8: return 1; case DataType::BF8: return 1; case DataType::FP64: return 8; - case DataType::INT32: return 4; + case DataType::I32: return 4; case DataType::I8: return 1; case DataType::I8_I8: return 2; case DataType::U8: return 1; diff --git a/experimental/builder/include/ck_tile/builder/testing/validation.hpp b/experimental/builder/include/ck_tile/builder/testing/validation.hpp index 267bf8d2ac..158f271e21 100644 --- a/experimental/builder/include/ck_tile/builder/testing/validation.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/validation.hpp @@ -7,7 +7,6 @@ #include "ck_tile/builder/testing/tensor_buffer.hpp" #include "ck_tile/builder/testing/tensor_foreach.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" -#include "ck/library/utility/check_err.hpp" #include "ck/utility/type_convert.hpp" #include #include diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index c1c62e91fa..c4cca05e52 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -24,7 +24,7 @@ enum class DataType FP8, BF8, FP64, - INT32, + I32, I8, I8_I8, U8 @@ -192,8 +192,8 @@ enum class TileConvSpecialization FILTER_3x3 }; -// Enums for the forward convolution specialization. -enum class ConvFwdSpecialization +// Enums for the convolution specializations. +enum class ConvSpecialization { DEFAULT, FILTER_1X1_PAD0, @@ -202,22 +202,6 @@ enum class ConvFwdSpecialization ODD_C }; -// Enums for the backward data convolution specialization. -enum class ConvBwdDataSpecialization -{ - DEFAULT, - FILTER_1X1_STRIDE1_PAD0, -}; - -// Enums for the backward weight convolution specialization. -enum class ConvBwdWeightSpecialization -{ - DEFAULT, - FILTER_1X1_STRIDE1_PAD0, - FILTER_1X1_PAD0, - ODD_C, -}; - // Enums for the Gemm padding. enum class GemmPadding { @@ -249,11 +233,13 @@ enum class PipelineScheduler enum class ConvAlgorithmSpecialization { LARGE_TENSOR, - REFERENCE // GPU reference implementation for validation + REFERENCE, // GPU reference implementation for validation, + TWO_STAGE, + MULTIPLE_D }; -// toString methods for enum classes -inline std::string_view toString(DataType dt) +// to_string methods for enum classes +inline std::string_view to_string(DataType dt) { using enum DataType; switch(dt) @@ -267,7 +253,7 @@ inline std::string_view toString(DataType dt) case FP8: return "FP8"; case BF8: return "BF8"; case FP64: return "FP64"; - case INT32: return "INT32"; + case I32: return "I32"; case I8: return "I8"; case I8_I8: return "I8_I8"; case U8: return "U8"; @@ -276,7 +262,7 @@ inline std::string_view toString(DataType dt) } } -inline std::string_view toString(ConvDirection dir) +inline std::string_view to_string(ConvDirection dir) { using enum ConvDirection; switch(dir) @@ -288,7 +274,7 @@ inline std::string_view toString(ConvDirection dir) } } -inline std::string_view toString(ElementwiseOperation op) +inline std::string_view to_string(ElementwiseOperation op) { using enum ElementwiseOperation; switch(op) @@ -332,7 +318,7 @@ inline std::string_view toString(ElementwiseOperation op) } } -inline std::string_view toString(PipelineVersion ver) +inline std::string_view to_string(PipelineVersion ver) { using enum PipelineVersion; switch(ver) @@ -347,7 +333,7 @@ inline std::string_view toString(PipelineVersion ver) } } -inline std::string_view toString(GemmSpecialization spec) +inline std::string_view to_string(GemmSpecialization spec) { using enum GemmSpecialization; switch(spec) @@ -372,9 +358,9 @@ inline std::string_view toString(GemmSpecialization spec) } } -inline std::string_view toString(ConvFwdSpecialization spec) +inline std::string_view to_string(ConvSpecialization spec) { - using enum ConvFwdSpecialization; + using enum ConvSpecialization; switch(spec) { case DEFAULT: return "DEFAULT"; @@ -386,31 +372,7 @@ inline std::string_view toString(ConvFwdSpecialization spec) } } -inline std::string_view toString(ConvBwdDataSpecialization spec) -{ - using enum ConvBwdDataSpecialization; - switch(spec) - { - case DEFAULT: return "DEFAULT"; - case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0"; - default: return "Unknown"; - } -} - -inline std::string_view toString(ConvBwdWeightSpecialization spec) -{ - using enum ConvBwdWeightSpecialization; - switch(spec) - { - case DEFAULT: return "DEFAULT"; - case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0"; - case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0"; - case ODD_C: return "ODD_C"; - default: return "Unknown"; - } -} - -inline std::string_view toString(GemmPadding padding) +inline std::string_view to_string(GemmPadding padding) { using enum GemmPadding; switch(padding) @@ -435,7 +397,7 @@ inline std::string_view toString(GemmPadding padding) } } -inline std::string_view toString(PipelineScheduler sched) +inline std::string_view to_string(PipelineScheduler sched) { using enum PipelineScheduler; switch(sched) @@ -447,7 +409,7 @@ inline std::string_view toString(PipelineScheduler sched) } } -inline std::string_view toString(TensorLayout layout) +inline std::string_view to_string(TensorLayout layout) { using enum TensorLayout; switch(layout) @@ -503,63 +465,46 @@ inline std::string_view toString(TensorLayout layout) } // ostream operator overloads for enum classes -inline std::ostream& operator<<(std::ostream& os, DataType dt) { return os << toString(dt); } +inline std::ostream& operator<<(std::ostream& os, DataType dt) { return os << to_string(dt); } -inline std::ostream& operator<<(std::ostream& os, ConvDirection dir) { return os << toString(dir); } +inline std::ostream& operator<<(std::ostream& os, ConvDirection dir) +{ + return os << to_string(dir); +} inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op) { - return os << toString(op); + return os << to_string(op); } inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver) { - return os << toString(ver); + return os << to_string(ver); } inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec) { - return os << toString(spec); + return os << to_string(spec); } -inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec) +inline std::ostream& operator<<(std::ostream& os, ConvSpecialization spec) { - return os << toString(spec); -} - -inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec) -{ - return os << toString(spec); -} - -inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec) -{ - return os << toString(spec); + return os << to_string(spec); } inline std::ostream& operator<<(std::ostream& os, GemmPadding padding) { - return os << toString(padding); + return os << to_string(padding); } inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched) { - return os << toString(sched); + return os << to_string(sched); } inline std::ostream& operator<<(std::ostream& os, TensorLayout layout) { - return os << toString(layout); -} - -// ostream operator overload for std::variant of convolution specializations -inline std::ostream& operator<<(std::ostream& os, - const std::variant& spec) -{ - std::visit([&os](const auto& s) { os << s; }, spec); - return os; + return os << to_string(layout); } } // namespace ck_tile::builder diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 233eafc366..ddcf8db476 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -83,11 +83,14 @@ add_ck_builder_test(test_ckb_conv_builder unit_tensor_foreach.cpp unit_error.cpp unit_validation.cpp + unit_debug.cpp + unit_conv_fwd_testing.cpp unit_conv_elementwise_op.cpp unit_conv_tensor_layout.cpp unit_conv_tensor_type.cpp unit_conv_thread_block.cpp unit_conv_tuning_params.cpp) +target_link_libraries(test_ckb_conv_builder PRIVATE utility) # Tests the inline diff utility used for comparing strings in tests assertions add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp) @@ -121,7 +124,7 @@ add_ck_builder_test(test_ckb_conv_description # Verifies that GetInstanceString() methods and other functions produce valid kernel code. # Tests various convolution types: # - Group convolution (v3, standard, large tensor, WMMA, DL variants) -# - Backward weight group convolution (XDL) +# - Backward weight group convolution (XDL standard, XDL v3, WMMA, DL, multiple D, two-stage variants) # Requires kernel compilation to validate the generated strings through the base class. set(INSTANCE_STRING_TESTS @@ -164,10 +167,35 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/ck/test_ckb_conv_fwd_3d_fp16.cpp conv/ck/test_ckb_conv_fwd_3d_fp32.cpp conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp - conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp - conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp) + ) target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility) +set(BWD_WEIGHT_TESTS + conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp + conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp + conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp + conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp + conv/ck/test_ckb_conv_bwd_weight_dl.cpp + conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp +) + +if (CK_USE_WMMA) + list(APPEND BWD_WEIGHT_TESTS + conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp + conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp + conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp + conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp + ) +endif() + +add_ck_builder_test(test_ckb_build_bwd_weight_instances ${BWD_WEIGHT_TESTS}) +target_link_libraries(test_ckb_build_bwd_weight_instances PRIVATE utility) + +add_ck_builder_test(test_ckb_build_bwd_data_instances + conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp + ) +target_link_libraries(test_ckb_build_bwd_data_instances PRIVATE utility) + ################################################################################ # FACTORY TESTS - Expensive Regression Tests (Full MIOpen Kernel Set) @@ -221,6 +249,8 @@ endforeach() set(CKB_REGRESSION_TESTS test_ckb_instance_string test_ckb_build_fwd_instances + test_ckb_build_bwd_weight_instances + test_ckb_build_bwd_data_instances test_ckb_testing_utils # test_ckb_factory_grouped_convolution_forward_convscale # test_ckb_factory_grouped_convolution_forward_scaleadd_ab diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp new file mode 100644 index 0000000000..584bce2f1b --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_dl.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl{} + .with_thread_block(cku::ThreadBlock_256_128x128x16) + .with_bwd_specialization(cku::ConvSpecialization::DEFAULT) + .with_dl_thread_config(cku::DlThreadConfig_16x1x4x4x1) + .with_dl_thread_cluster(cku::DlThreadCluster_8x2) + .with_dl_transfer(cku::DlTransfer5D); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DBf16_DL, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeight_Dl", + expected_transfer_parameters, + "Default", + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough"}); +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp new file mode 100644 index 0000000000..404d1dbacd --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_wmma_cshuffle_v3.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 3, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNDHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKZYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNDHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_block_gemm(cku::BlockGemmDesc_v1_intrawave); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_3DFp16_MultiD_Wmma_ShuffleV3_GNHWC, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3", + expected_transfer_parameters, + "Default", + "GNDHWC,GKZYXC,GNDHWK", + "PassThrough,PassThrough,PassThrough", + "fp16,fp16>"}); // check compute types +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp new file mode 100644 index 0000000000..206fc8beb9 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_multi_d_xdl_cshuffle.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle{} + .with_thread_block(cku::ThreadBlock_256_128x128x8) + .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) + .with_transfer(cku::BwdTransfer_4x64x1) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DFp16_MultiD_CShuffle_GNHWC, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle", + expected_transfer_parameters, + "Default", + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough", + "fp16,fp16>"}); // check compute types +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp new file mode 100644 index 0000000000..782f33f845 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_wmma_cshuffle_v3.cpp @@ -0,0 +1,46 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; + +constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCHW}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NGKHW}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_block_gemm(cku::BlockGemmDesc_v1_intrawave) + .with_num_conv_groups_to_merge(2) + .with_transpose_params(2, 2); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DFp16_TwoStage_Wmma_CShuffle_V3, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3", + expected_transfer_parameters, + "Default", + "NGCHW,GKYXC,NGKHW", + "PassThrough,PassThrough,PassThrough", + "Intrawave", + "v1", + "fp16,fp16,2,2>"}); // Check compute types and transpose params. +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp new file mode 100644 index 0000000000..a2a877dbcd --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_two_stage_xdl_cshuffle.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_block_gemm(cku::BlockGemmDesc_v2_intrawave) + .with_num_conv_groups_to_merge(2) + .with_transpose_params(2, 4); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DBf16_TwoStage_CShuffle, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + cku::run_test({"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle", + expected_transfer_parameters, + "Default", + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough", + "Intrawave,v2", // pipeline versions + "bf16,bf16,2,4>"}); // compute types and transpose params +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp new file mode 100644 index 0000000000..ff350ac804 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; + +constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 3, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCDHW}}, + .weight = {.config = {.layout = GKZYXC}}, + .output = {.config = {.layout = NGKDHW}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT) + .with_gridwise_gemm_pipeline(ckb::PipelineVersion::V1); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_3DBf16_Wmma_CShuffle, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeight_Wmma_CShuffle", + expected_transfer_parameters, + "Default", + "NGCDHW,GKZYXC,NGKDHW", + "PassThrough,PassThrough,PassThrough", + "v1"}); +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp new file mode 100644 index 0000000000..60f7d5bd64 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle_v3.cpp @@ -0,0 +1,46 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; + +constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCW}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = NGKW}}}; + +constexpr auto ALGORITHM = + cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) + .with_block_gemm(cku::BlockGemmDesc_v1_intrawave) + .with_transpose_params(4, 4); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_1DBf16_Wmma_CShuffle_V3, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; + cku::run_test({"DeviceGroupedConvBwdWeight_Wmma_CShuffleV3", + expected_transfer_parameters, + "Filter1x1Stride1Pad0", + "NGCW,GKXC,NGKW", + "PassThrough,PassThrough,PassThrough", + "Intrawave", + "v1", + "bf16,bf16,4,4>"}); +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp new file mode 100644 index 0000000000..892f1d35ef --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; + +constexpr auto SIGNATURE = + ckt::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = ckb::TensorLayout::GNHWC}}, + .weight = {.config = {.layout = ckb::TensorLayout::GKYXC}}, + .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; + +constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle{} + .with_thread_block(cku::ThreadBlock_256_128x128x8) + .with_gemm_config(cku::BwdGemmParams_Xdl_4x4_per_wave) + .with_transfer(cku::BwdTransfer_4x64x1) + .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) + .with_transpose_params(2, 2); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_2DFp16_CShuffle_GNHWC, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + cku::run_test({"DeviceGroupedConvBwdWeight_Xdl_CShuffle", + expected_transfer_parameters, + "Default", + "GNHWC,GKYXC,GNHWK", + "PassThrough,PassThrough,PassThrough", + "fp16,fp16,2,2>"}); // check compute types and transpose params +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp new file mode 100644 index 0000000000..4ad97209e5 --- /dev/null +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "utils/ckb_conv_test_configs.hpp" +#include "utils/ckb_conv_test_utils.hpp" +#include "utils/conv_algorithm_type_utils.hpp" +#include "ck_tile/host/device_prop.hpp" + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; + +constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::BF16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NGCW}}, + .weight = {.config = {.layout = GKXC}}, + .output = {.config = {.layout = NGKW}}}; + +constexpr auto ALGORITHM = + cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{} + .with_thread_block(cku::ThreadBlock_64_32x32x32) + .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) + .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) + .with_block_gemm(cku::BlockGemmDesc_v2_intrawave); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +TEST(BwdWeight_1DBf16_CShuffle_V3, Create) +{ + const auto expected_transfer_parameters = to_string(ALGORITHM); + cku::run_test({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3", + expected_transfer_parameters, + "Filter1x1Stride1Pad0", + "NGCW,GKXC,NGKW", + "PassThrough,PassThrough,PassThrough", + "Intrawave", + "v2"}); +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp index 284b3929ee..8d85370b26 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_bf16.cpp @@ -30,11 +30,11 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, - GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0, + GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v2_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp index 6802e0caf8..d3ace110c4 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_fp16.cpp @@ -27,11 +27,12 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_64_64x32x32) + .with_thread_block(ThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(FwdTransfer_4x16x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 2, PipelineScheduler::DEFAULT); + .with_transfer(Transfer_4x16x1) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(2); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index 14463bbc17..06d200429c 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -22,18 +22,20 @@ TEST(FwdConvInstances, constexpr ConvSignature FwdConvSignature{.spatial_dim = 1, .direction = FORWARD, .data_type = I8, - .accumulation_data_type = INT32, + .accumulation_data_type = I32, .input = {.config = {.layout = GNWC}}, .weight = {.config = {.layout = GKXC}}, .output = {.config = {.layout = GNWK}}}; constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{} - .with_thread_block(FwdThreadBlock_128_64x64x64) - .with_gemm_config(FwdGemmParams_Wmma_2x1_per_wave) - .with_transfer(FwdTransfer_4x32x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 0, PipelineScheduler::DEFAULT); + .with_thread_block(ThreadBlock_128_64x64x64) + .with_gemm_config(GemmParams_Wmma_2x1_per_wave) + .with_transfer(Transfer_4x32x1) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(2) + .with_gridwise_gemm_pipeline(PipelineVersion::V1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp index 4a5618a6b1..610e2fad5f 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16.cpp @@ -27,10 +27,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v1_intrawave); using Builder = ConvBuilder; @@ -64,10 +64,11 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::FILTER_3x3, GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvSpecialization::FILTER_3x3, + GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v5_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp index 0d9563e05a..23edef5436 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp @@ -32,11 +32,12 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_64_64x32x32) + .with_thread_block(ThreadBlock_64_64x32x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(FwdTransfer_4x16x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); + .with_transfer(Transfer_4x16x1) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp index 9bea834ef9..58171cd530 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp @@ -25,15 +25,16 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} - .with_thread_block(FwdThreadBlock_256_128x128x16) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_thread_block(ThreadBlock_256_128x128x16) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) .with_dl_thread_cluster(DlThreadCluster_8x2) - .with_dl_transfer(DlFwdTransfer); + .with_dl_transfer(DlTransfer4D); using Builder = ConvBuilder; const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; run_test({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", expected_transfer_parameters, "Default", @@ -59,16 +60,17 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{} - .with_thread_block(FwdThreadBlock_256_128x128x16) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, - GemmSpecialization::MNKPadding) + .with_thread_block(ThreadBlock_256_128x128x16) + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) .with_dl_thread_config(DlThreadConfig_16x2x4x4x1) .with_dl_thread_cluster(DlThreadCluster_8x2) - .with_dl_transfer(DlFwdTransfer); + .with_dl_transfer(DlTransfer4D); using Builder = ConvBuilder; const auto expected_transfer_parameters = to_string(FwdConvAlgorithm); + std::cout << "Expected Transfer Parameters: " << expected_transfer_parameters << std::endl; run_test({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", expected_transfer_parameters, "Filter1x1Pad0", diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index 1ba811bbe0..3e5e39191e 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -25,11 +25,11 @@ constexpr auto SIGNATURE = .output = {.config = {.layout = ckb::TensorLayout::GNHWK}}}; constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(cku::FwdThreadBlock_256_256x256x32) + .with_thread_block(cku::ThreadBlock_256_256x256x32) .with_gemm_config(cku::FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(cku::FwdTransfer_4x64x1) - .with_specializations(ckb::ConvFwdSpecialization::DEFAULT, - ckb::GemmSpecialization::MNKPadding) + .with_transfer(cku::Transfer_4x64x1) + .with_fwd_specializations(ckb::ConvSpecialization::DEFAULT, + ckb::GemmSpecialization::MNKPadding) .with_block_gemm(cku::BlockGemmDesc_v3_intrawave); using Builder = ckb::ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp index 79ee4915e8..bb35c53ba0 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp32.cpp @@ -26,11 +26,11 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_128x128x32) + .with_thread_block(ThreadBlock_256_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0, - GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_STRIDE1_PAD0, + GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v4_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp index 3e3d7e8c2b..b117e693fe 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp8.cpp @@ -27,11 +27,12 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_256_256x128x32) + .with_thread_block(ThreadBlock_256_256x128x32) .with_gemm_config(FwdGemmParams_Xdl_4x2_per_wave) - .with_transfer(FwdTransfer_4x64x1_fp8) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT); + .with_transfer(Transfer_4x64x1_fp8) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp index 3019c57a18..97bc0a00e5 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp @@ -25,14 +25,13 @@ TEST(FwdConvInstances, .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ - .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_256_256x128x32) - .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(FwdTransfer_4x16x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, - GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; + ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{} + .with_thread_block(ThreadBlock_256_256x128x32) + .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) + .with_transfer(Transfer_4x16x1) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; @@ -62,14 +61,14 @@ TEST( .output = {.config = {.layout = GNHWK}}}; constexpr auto FwdConvAlgorithm = - ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{ - .base_algorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{} - .with_thread_block(FwdThreadBlock_128_128x128x32) - .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(FwdTransfer_4x16x1) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, - GemmSpecialization::MNKPadding) - .with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)}; + ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{} + .with_thread_block(ThreadBlock_128_128x128x32) + .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) + .with_transfer(Transfer_4x16x1) + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) + .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_num_conv_groups_to_merge(1); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp index 3f9bdfb972..9e6ca00e58 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_bf16.cpp @@ -27,10 +27,10 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v3_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp index b30f958bc4..56d4b8be59 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp16.cpp @@ -27,11 +27,11 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_128x128x32) + .with_thread_block(ThreadBlock_256_128x128x32) .with_gemm_config(FwdGemmParams_Xdl_2x1_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, - GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v4_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp index 33c01c8ac4..df8339241b 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_3d_fp32.cpp @@ -27,11 +27,11 @@ TEST(FwdConvInstances, constexpr auto FwdConvAlgorithm = ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{} - .with_thread_block(FwdThreadBlock_256_256x256x32) + .with_thread_block(ThreadBlock_256_256x256x32) .with_gemm_config(FwdGemmParams_Xdl_4x4_per_wave) - .with_transfer(FwdTransfer_4x64x1) - .with_specializations(ConvFwdSpecialization::FILTER_1X1_PAD0, - GemmSpecialization::MNKPadding) + .with_transfer(Transfer_4x64x1) + .with_fwd_specializations(ConvSpecialization::FILTER_1X1_PAD0, + GemmSpecialization::MNKPadding) .with_block_gemm(BlockGemmDesc_v1_intrawave); using Builder = ConvBuilder; diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index d5661ad67b..b3a76e4e11 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -101,7 +101,7 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) // Verify specializations EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); // Verify algorithm information EXPECT_EQ(Traits::thread_block_size, 256); @@ -229,7 +229,7 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction) // Verify specializations EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); // Verify algorithm information EXPECT_EQ(Traits::thread_block_size, 256); @@ -313,7 +313,7 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) // Verify specializations EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); // Verify algorithm information EXPECT_EQ(Traits::thread_block_size, 256); diff --git a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp index de2a4fdd14..9d6fab19d1 100644 --- a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp @@ -230,7 +230,7 @@ TEST(InstanceToConvTraits, ExtractsDefaultSpecialization) using Traits = ck_tile::reflect::conv::ConvTraits; - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); } TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization) @@ -289,8 +289,7 @@ TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization) using Traits = ck_tile::reflect::conv::ConvTraits; - EXPECT_EQ(Traits::conv_specialization, - ck_tile::builder::ConvFwdSpecialization::FILTER_1X1_PAD0); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0); } // ============================================================================ diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp index 91c75e3e8d..89baf9b51b 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp @@ -8,26 +8,27 @@ namespace { using namespace ck_tile::builder::test_utils; -TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::BACKWARD_DATA, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; + constexpr ConvSignature BwdDataConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::BACKWARD_DATA, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; - constexpr auto FwdConvAlgorithm = + constexpr auto BwdDataConvAlgorithm = ConvAlgorithm_Tile_GroupedConvolutionKernel{} .with_tile_specializations(TileConvSpecialization::DEFAULT) - .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_thread_block(TileThreadBlock_64x64x64) .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) - .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_transfer(TileTransfer_4x4x4) .with_tile_optimizations(TileOptimizations{ .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); - using Builder = ConvBuilder; + using Builder = ConvBuilder; run_ck_tile_test({ "grouped_convolution_backward_data", "fp16", diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp index e2e165967a..292d852b91 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp @@ -8,26 +8,27 @@ namespace { using namespace ck_tile::builder::test_utils; -TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) { - constexpr ConvSignature FwdConvSignature{.spatial_dim = 2, - .direction = ConvDirection::BACKWARD_WEIGHT, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; + constexpr ConvSignature BwdWeightConvSignature{ + .spatial_dim = 2, + .direction = ConvDirection::BACKWARD_WEIGHT, + .data_type = DataType::FP16, + .accumulation_data_type = DataType::FP32, + .input = {.config = {.layout = TensorLayout::NHWGC}}, + .weight = {.config = {.layout = TensorLayout::GKYXC}}, + .output = {.config = {.layout = TensorLayout::NHWGK}}}; - constexpr auto FwdConvAlgorithm = + constexpr auto BwdWeightConvAlgorithm = ConvAlgorithm_Tile_GroupedConvolutionKernel{} .with_tile_specializations(TileConvSpecialization::DEFAULT) - .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_thread_block(TileThreadBlock_64x64x64) .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) - .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_transfer(TileTransfer_4x4x4) .with_tile_optimizations(TileOptimizations{ .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); - using Builder = ConvBuilder; + using Builder = ConvBuilder; run_ck_tile_test({ "grouped_convolution_backward_weight", "fp16", diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp index 5ec73d780f..2c35fb5076 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp @@ -21,9 +21,9 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 constexpr auto FwdConvAlgorithm = ConvAlgorithm_Tile_GroupedConvolutionKernel{} .with_tile_specializations(TileConvSpecialization::DEFAULT) - .with_tile_thread_block(FwdTileThreadBlock_64x64x64) + .with_tile_thread_block(TileThreadBlock_64x64x64) .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) - .with_tile_transfer(FwdTileTransfer_4x4x4) + .with_tile_transfer(TileTransfer_4x4x4) .with_tile_optimizations(TileOptimizations{ .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index bf61eb7026..617686fda1 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -28,18 +28,31 @@ struct ThreadBlock }; static_assert(ckb::ThreadBlockDescriptor); -// Describe gridwise XDL GEMM parameters. -struct GridwiseXdlGemm +struct XdlParams { - // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! - size_t ak1 = 0; - size_t bk1 = 0; size_t m_per_xdl = 0; size_t n_per_xdl = 0; size_t m_xdl_per_wave = 0; size_t n_xdl_per_wave = 0; }; -static_assert(ckb::GridwiseXdlGemmDescriptor); +static_assert(ckb::GridwiseXdlGemmDescriptor); + +// Describe gridwise XDL GEMM parameters. +struct GridwiseFwdXdlGemm +{ + // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! + size_t ak1 = 0; + size_t bk1 = 0; + XdlParams xdl_params; +}; +static_assert(ckb::GridwiseFwdXdlGemmDescriptor); + +struct GridwiseBwdXdlGemm +{ + size_t k1 = 0; + XdlParams xdl_params; +}; +static_assert(ckb::GridwiseBwdXdlGemmDescriptor); // Describe gridwise WMMA GEMM parameters. struct GridwiseWmmaGemm @@ -49,25 +62,36 @@ struct GridwiseWmmaGemm size_t n_per_wmma = 0; size_t m_wmma_per_wave = 0; size_t n_wmma_per_wave = 0; - PipelineVersion pipeline_version; }; static_assert(ckb::GridwiseWmmaGemmDescriptor); -struct BlockGemm +struct BlockGemmPipeline { PipelineVersion pipeline_version; PipelineScheduler scheduler; }; -static_assert(ckb::BlockGemmDescriptor); +static_assert(ckb::BlockGemmPipelineDescriptor); // Describe Aand B block transfer thread cluster lengths. +template struct BlockTransfer { size_t k0; size_t m_n; size_t k1; + size_t k_batch_size; }; -static_assert(ckb::BlockTransferDescriptor); + +// Specialization for ThreadSliceLength == 3 +template <> +struct BlockTransfer<3> +{ + size_t k0; + size_t m_n; + size_t k1; +}; +static_assert(ckb::BlockTransferDescriptor, 3>); +static_assert(ckb::BlockTransferDescriptor, 4>); // Describe C block transfer thread cluster lengths. struct ThreadCluster @@ -97,31 +121,35 @@ struct Epilogue }; static_assert(EpilogueDescriptor); +template struct AccessOrder { - std::array order; + std::array order; }; -static_assert(AccessOrderDescriptor); +static_assert(AccessOrderDescriptor>); +static_assert(AccessOrderDescriptor>); -struct TransferAB +template +struct InputTransfer { - BlockTransfer block_transfer; + BlockTransfer block_transfer; LdsTransfer lds_transfer; - AccessOrder block_transfer_access_order; - AccessOrder src_access_order; + AccessOrder block_transfer_access_order; + AccessOrder src_access_order; }; -struct TransferC +struct OutputTransfer { ThreadCluster thread_cluster_dims; Epilogue epilogue; }; -struct TransferABC +template +struct Transfer { - TransferAB a; - TransferAB b; - TransferC c; + InputTransfer a; + InputTransfer b; + OutputTransfer c; }; // DL-specific descriptors @@ -142,17 +170,19 @@ struct DlThreadCluster }; static_assert(ckb::DlThreadClusterDescriptor); +template struct DlBlockTransfer { - std::array thread_slice_lengths; - std::array thread_cluster_lengths; - std::array thread_cluster_arrange_order; - std::array src_access_order; - std::array src_vector_tensor_lengths; - std::array src_vector_tensor_contiguous_dim_order; - std::array dst_vector_tensor_lengths; + std::array thread_slice_lengths; + std::array thread_cluster_lengths; + std::array thread_cluster_arrange_order; + std::array src_access_order; + std::array src_vector_tensor_lengths; + std::array src_vector_tensor_contiguous_dim_order; + std::array dst_vector_tensor_lengths; }; -static_assert(ckb::DlBlockTransferDescriptor); +static_assert(ckb::DlBlockTransferDescriptor4D>); +static_assert(ckb::DlBlockTransferDescriptor5D>); struct DlEpilogue { @@ -169,9 +199,14 @@ struct ThreadBlock_ ThreadBlock thread_block; }; -struct XdlGemm_ +struct FwdXdlGemm_ { - GridwiseXdlGemm gridwise_gemm; + GridwiseFwdXdlGemm gridwise_gemm; +}; + +struct BwdXdlGemm_ +{ + GridwiseBwdXdlGemm gridwise_gemm; }; struct WmmaGemm_ @@ -179,27 +214,48 @@ struct WmmaGemm_ GridwiseWmmaGemm gridwise_gemm; }; +template struct Transfer_ { - TransferABC transfer; + Transfer transfer; }; -struct ConvSpecialization_ +struct ConvSpecializationFwd_ { - ConvFwdSpecialization fwd_specialization; + ConvSpecialization fwd_specialization; GemmSpecialization gemm_specialization; }; +struct ConvSpecializationBwdWeight_ +{ + ConvSpecialization bwd_weight_specialization; +}; + struct Prefetch_ { size_t num_gemm_k_prefetch_stages; - size_t num_groups_to_merge; PipelineScheduler loop_scheduler; }; +struct TransposeParams_ +{ + size_t max_transpose_transfer_src_scalar_per_vector{1}; + size_t max_transpose_transfer_dst_scalar_per_vector{1}; +}; + +struct GemmBatchOptions_ +{ + size_t num_conv_groups_to_merge{1}; +}; + struct BlockGemm_ { - BlockGemm block_gemm; + BlockGemmPipeline block_gemm_pipeline; +}; + +struct GridGemm_ +{ + PipelineVersion pipeline_version; }; struct DlThreadConfig_ @@ -212,33 +268,34 @@ struct DlThreadCluster_ DlThreadCluster thread_cluster; }; -struct DlBlockTransferAB +template +struct DlTransfer { - DlBlockTransfer block_transfer; -}; - -struct DlBlockTransferC -{ - DlEpilogue epilogue; -}; - -struct DlTransferABC -{ - DlBlockTransferAB a; - DlBlockTransferAB b; - DlBlockTransferC c; + DlBlockTransfer a; + DlBlockTransfer b; + DlEpilogue c; }; +template struct DlTransfer_ { - DlTransferABC transfer; + DlTransfer transfer; }; -// Specialization wrapper for large tensor support -template -struct LargeTensorWrapper +struct TwoStageSpecialization_ +{ + static constexpr ConvAlgorithmSpecialization specialization = + ConvAlgorithmSpecialization::TWO_STAGE; +}; + +struct MultipleDSpecialization_ +{ + static constexpr ConvAlgorithmSpecialization specialization = + ConvAlgorithmSpecialization::MULTIPLE_D; +}; + +struct LargeTensorSpecialization_ { - BaseAlgorithm base_algorithm; static constexpr ConvAlgorithmSpecialization specialization = ConvAlgorithmSpecialization::LARGE_TENSOR; }; @@ -329,7 +386,11 @@ struct ConvAlgorithmTemplate : Components... constexpr auto with_gemm_config(const GemmConfig& gemm) const { auto result = *this; - if constexpr(std::is_base_of_v) + if constexpr(std::is_base_of_v) + { + result.gridwise_gemm = gemm; + } + else if constexpr(std::is_base_of_v) { result.gridwise_gemm = gemm; } @@ -337,46 +398,82 @@ struct ConvAlgorithmTemplate : Components... { result.gridwise_gemm = gemm; } + else + { + static_assert(false, "Unrecognized GemmConfig type"); + } return result; } template constexpr auto with_transfer(const T& t) const { - static_assert(std::is_base_of_v); + static_assert(std::is_base_of_v, ConvAlgorithmTemplate> || + std::is_base_of_v, ConvAlgorithmTemplate>); auto result = *this; result.transfer = t; return result; } - constexpr auto with_specializations(ConvFwdSpecialization fwd_spec, - GemmSpecialization gemm_spec) const + constexpr auto with_fwd_specializations(ConvSpecialization fwd_spec, + GemmSpecialization gemm_spec) const { - static_assert(std::is_base_of_v); + static_assert(std::is_base_of_v); auto result = *this; result.fwd_specialization = fwd_spec; result.gemm_specialization = gemm_spec; return result; } - constexpr auto with_prefetch_config(size_t k_prefetch_stages, - size_t groups_to_merge, - PipelineScheduler scheduler) const + constexpr auto with_bwd_specialization(ConvSpecialization bwd_spec) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.bwd_weight_specialization = bwd_spec; + return result; + } + + constexpr auto with_prefetch_config(size_t k_prefetch_stages, PipelineScheduler scheduler) const { static_assert(std::is_base_of_v); auto result = *this; result.num_gemm_k_prefetch_stages = k_prefetch_stages; - result.num_groups_to_merge = groups_to_merge; result.loop_scheduler = scheduler; return result; } + constexpr auto with_transpose_params(size_t max_src_scalar_per_vector, + size_t max_dst_scalar_per_vector) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.max_transpose_transfer_src_scalar_per_vector = max_src_scalar_per_vector; + result.max_transpose_transfer_dst_scalar_per_vector = max_dst_scalar_per_vector; + return result; + } + + constexpr auto with_num_conv_groups_to_merge(size_t num_groups_to_merge) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.num_conv_groups_to_merge = num_groups_to_merge; + return result; + } + template constexpr auto with_block_gemm(const BG& bg) const { static_assert(std::is_base_of_v); - auto result = *this; - result.block_gemm = bg; + auto result = *this; + result.block_gemm_pipeline = bg; + return result; + } + + constexpr auto with_gridwise_gemm_pipeline(const PipelineVersion plv) const + { + static_assert(std::is_base_of_v); + auto result = *this; + result.pipeline_version = plv; return result; } @@ -401,7 +498,8 @@ struct ConvAlgorithmTemplate : Components... template constexpr auto with_dl_transfer(const T& t) const { - static_assert(std::is_base_of_v); + static_assert(std::is_base_of_v, ConvAlgorithmTemplate> || + std::is_base_of_v, ConvAlgorithmTemplate>); auto result = *this; result.transfer = t; return result; @@ -453,26 +551,49 @@ struct ConvAlgorithmTemplate : Components... } }; -// Algorithm types +// Fwd algorithm types using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate, + ConvSpecializationFwd_, + Prefetch_, + GemmBatchOptions_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate, + ConvSpecializationFwd_, + BlockGemm_>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle = - ConvAlgorithmTemplate; + ConvAlgorithmTemplate, + ConvSpecializationFwd_, + GridGemm_, + Prefetch_, + GemmBatchOptions_>; + using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = ConvAlgorithmTemplate; + DlTransfer_<>>; using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = - LargeTensorWrapper; + ConvAlgorithmTemplate, + ConvSpecializationFwd_, + Prefetch_, + GemmBatchOptions_, + LargeTensorSpecialization_>; +// CK Tile algorithm using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + TransposeParams_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_, + TransposeParams_, + GemmBatchOptions_, + TwoStageSpecialization_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + MultipleDSpecialization_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_, + TransposeParams_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_, + TransposeParams_, + GemmBatchOptions_, + TwoStageSpecialization_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + GridGemm_, + Prefetch_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_, + MultipleDSpecialization_>; + } // namespace ck_tile::builder::test diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 5d6bc102e6..9e8008ccf0 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -120,14 +120,12 @@ struct DefaultAlgorithm ckb::test::ThreadBlock thread_block{.block_size = 256, .tile_size = {.m = 256, .n = 256, .k = 32}}; - ckb::test::GridwiseXdlGemm gridwise_gemm{.ak1 = 8, - .bk1 = 8, - .m_per_xdl = 16, - .n_per_xdl = 16, - .m_xdl_per_wave = 8, - .n_xdl_per_wave = 8}; + ckb::test::GridwiseFwdXdlGemm gridwise_gemm{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 16, .n_per_xdl = 16, .m_xdl_per_wave = 8, .n_xdl_per_wave = 8}}; - ckb::test::TransferABC transfer{ + ckb::test::Transfer<> transfer{ .a = { .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, @@ -161,10 +159,11 @@ struct DefaultAlgorithm }, }; - ckb::ConvFwdSpecialization fwd_specialization = ckb::ConvFwdSpecialization::DEFAULT; - ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; - ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4, - .scheduler = ckb::PipelineScheduler::INTRAWAVE}; + ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT; + ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; + ckb::test::BlockGemmPipeline block_gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4, + .scheduler = + ckb::PipelineScheduler::INTRAWAVE}; }; static_assert(ckb::ConvAlgorithmDescriptor); diff --git a/experimental/builder/test/unit_conv_fwd_testing.cpp b/experimental/builder/test/unit_conv_fwd_testing.cpp index 3243935ca5..be95a29a2d 100644 --- a/experimental/builder/test/unit_conv_fwd_testing.cpp +++ b/experimental/builder/test/unit_conv_fwd_testing.cpp @@ -4,6 +4,7 @@ #include "impl/conv_signature_types.hpp" #include "testing_utils.hpp" #include "ck_tile/builder/testing/conv_fwd.hpp" +#include "ck_tile/builder/testing/tensor_foreach.hpp" #include #include #include @@ -12,6 +13,7 @@ namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; using ::testing::ElementsAreArray; +using ::testing::Eq; using ::testing::NotNull; constexpr auto SIGNATURE = @@ -57,6 +59,8 @@ using UniqueOutputs = ckt::UniqueOutputs; static_assert(ckt::ValidUniqueInputs); static_assert(ckt::ValidUniqueOutputs); +static_assert(ckt::TensorReflectable); +static_assert(ckt::TensorReflectable); TEST(ConvFwdTesting, MakeDescriptors) { @@ -81,3 +85,41 @@ TEST(ConvFwdTesting, Alloc) EXPECT_THAT(inputs.get().weight, NotNull()); EXPECT_THAT(outputs.get().output, NotNull()); } + +TEST(ConvFwdTesting, Validate) +{ + auto a = alloc_outputs(ARGS); + auto b = alloc_outputs(ARGS); + + // Positive test + { + ckt::Outputs::reflect( + ARGS, + [&]([[maybe_unused]] std::string_view name, + const auto& desc, + void* ckt::Outputs::*ptr) { + ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{123}); + ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{123}); + }); + + const auto report = ckt::validate(ARGS, a.get(), b.get()); + EXPECT_THAT(report.get_errors().size(), Eq(0)); + } + + // Negative test + { + size_t field_count = 0; + ckt::Outputs::reflect( + ARGS, + [&]([[maybe_unused]] std::string_view name, + const auto& desc, + void* ckt::Outputs::*ptr) { + ++field_count; + ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{2}); + ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{1}); + }); + + const auto report = ckt::validate(ARGS, a.get(), b.get()); + EXPECT_THAT(report.get_errors().size(), Eq(field_count)); + } +} diff --git a/experimental/builder/test/unit_conv_tensor_layout.cpp b/experimental/builder/test/unit_conv_tensor_layout.cpp index ce31f41933..0df94d977e 100644 --- a/experimental/builder/test/unit_conv_tensor_layout.cpp +++ b/experimental/builder/test/unit_conv_tensor_layout.cpp @@ -38,11 +38,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK) .weight = {.config = {.layout = GKXC}}, .output = {.config = {.layout = NWGK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -57,11 +57,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW) .weight = {.config = {.layout = GKXC}}, .output = {.config = {.layout = NGKW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -76,11 +76,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK) .weight = {.config = {.layout = GKXC}}, .output = {.config = {.layout = GNWK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -95,11 +95,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW) .weight = {.config = {.layout = GKCX}}, .output = {.config = {.layout = NGKW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -114,11 +114,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW) .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = NGKHW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -133,11 +133,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK) .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = NHWGK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -152,11 +152,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK) .weight = {.config = {.layout = GKYXC}}, .output = {.config = {.layout = GNHWK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -171,11 +171,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW) .weight = {.config = {.layout = GKCYX}}, .output = {.config = {.layout = NGKHW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -190,11 +190,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW) .weight = {.config = {.layout = GKCZYX}}, .output = {.config = {.layout = NGKDHW}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -209,11 +209,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK) .weight = {.config = {.layout = GKZYXC}}, .output = {.config = {.layout = NDHWGK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -228,11 +228,11 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK) .weight = {.config = {.layout = GKZYXC}}, .output = {.config = {.layout = GNDHWK}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); EXPECT_TRUE((std::is_same_v>)); } @@ -273,7 +273,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout) static constexpr std::array aux_configs = { MockAuxiliaryTensorConfig{.layout = G_K_strided}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 1); using ExpectedType = ck::Tuple; @@ -287,7 +287,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout) static constexpr std::array aux_configs = { MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 1); using ExpectedType = ck::Tuple; @@ -301,7 +301,7 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout) static constexpr std::array aux_configs = { MockAuxiliaryTensorConfig{.layout = G_C_strided}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 1); using ExpectedType = ck::Tuple; @@ -316,7 +316,7 @@ TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors) MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}, MockAuxiliaryTensorConfig{.layout = GC}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 2); using ExpectedType = @@ -333,7 +333,7 @@ TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors) MockAuxiliaryTensorConfig{.layout = GC}, MockAuxiliaryTensorConfig{.layout = G_C_strided}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 3); using ExpectedType = ck::Tuple aux_configs = { MockAuxiliaryTensorConfig{.layout = G_K_strided}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 1); using ExpectedType = ck::Tuple; @@ -363,7 +363,7 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution) static constexpr std::array aux_configs = { MockAuxiliaryTensorConfig{.layout = GC}}; - using AuxLayouts = AuxiliaryTensorLayouts; + using AuxLayouts = AuxiliaryTensorLayouts; EXPECT_EQ(AuxLayouts::Size, 1); using ExpectedType = ck::Tuple; @@ -387,11 +387,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); using ExpectedDsLayout = ck::Tuple; EXPECT_TRUE((std::is_same_v)); @@ -414,11 +414,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); using ExpectedDsLayout = ck::Tuple; EXPECT_TRUE((std::is_same_v)); @@ -442,11 +442,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALEADD_SCALEADD_RELU}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); using ExpectedDsLayout = ck::Tuple; @@ -470,11 +470,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); using ExpectedDsLayout = ck::Tuple; EXPECT_TRUE((std::is_same_v)); @@ -497,11 +497,11 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias) .operation = OutputOp{.elementwise_operation = ElementwiseOperation::BIAS_BNORM_CLAMP}}}; - using TensorLayouts = ConvTensorLayouts; + using TensorLayouts = ConvTensorLayouts; - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); - EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); using ExpectedDsLayout = ck::Tuple; EXPECT_TRUE((std::is_same_v)); diff --git a/experimental/builder/test/unit_conv_tensor_type.cpp b/experimental/builder/test/unit_conv_tensor_type.cpp index b385210cea..b32ce339fa 100644 --- a/experimental/builder/test/unit_conv_tensor_type.cpp +++ b/experimental/builder/test/unit_conv_tensor_type.cpp @@ -27,7 +27,7 @@ TEST(ConvTensorType, Exhaustive) case FP32: EXPECT_TRUE((check_same)); break; case FP16: EXPECT_TRUE((check_same)); break; case BF16: EXPECT_TRUE((check_same)); break; - case INT32: EXPECT_TRUE((check_same)); break; + case I32: EXPECT_TRUE((check_same)); break; case FP8: EXPECT_TRUE((check_same)); break; case I8: EXPECT_TRUE((check_same)); break; case U8: EXPECT_TRUE((check_same)); break; diff --git a/experimental/builder/test/unit_conv_tuning_params.cpp b/experimental/builder/test/unit_conv_tuning_params.cpp index b35a1ced55..9005742930 100644 --- a/experimental/builder/test/unit_conv_tuning_params.cpp +++ b/experimental/builder/test/unit_conv_tuning_params.cpp @@ -19,7 +19,7 @@ TEST(ConvTuningParams, AssignsBlockGemmParams) { ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3; ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTRAWAVE; - } block_gemm; + } block_gemm_pipeline; } kAlgorithm; constexpr auto block_gemm = SetBlockGemm(); @@ -42,10 +42,7 @@ TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion) { constexpr struct Algorithm { - struct GridwiseGemm - { - ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4; - } gridwise_gemm; + ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4; } kAlgorithm; constexpr auto pipeline_version = SetGridwiseGemmPipelineVersion(); @@ -78,8 +75,8 @@ TEST(ConvTuningParams, AssignsFwdConvSpecialization) { constexpr struct Algorithm { - ckb::ConvFwdSpecialization fwd_specialization = - ckb::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0; + ckb::ConvSpecialization fwd_specialization = + ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0; } kAlgorithm; constexpr auto conv_spec = SetFwdConvSpecialization(); diff --git a/experimental/builder/test/unit_debug.cpp b/experimental/builder/test/unit_debug.cpp new file mode 100644 index 0000000000..80ff291782 --- /dev/null +++ b/experimental/builder/test/unit_debug.cpp @@ -0,0 +1,464 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/testing/tensor_foreach.hpp" +#include "ck_tile/builder/testing/debug.hpp" +#include "testing_utils.hpp" +#include +#include +#include +#include + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; + +using ck_tile::test::StringEqWithDiff; +using ::testing::ElementsAreArray; +using ::testing::Eq; +using ::testing::Gt; + +TEST(Debug, PrintDescriptor) +{ + auto desc = + ckt::make_descriptor(ckt::Extent{10, 11, 12}, ckt::PackedRightLayout{}); + + std::stringstream ss; + ckt::print_descriptor("test", desc, ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Descriptor \"test\":\n" + " data type: I32\n" + " size: 1'320 elements\n" + " space: 1'320 elements (5'280 bytes)\n" + " lengths: [10, 11, 12]\n" + " strides: [132, 12, 1]\n" + " packed: yes\n")); + + // Make sure that the stream locale does not leak. + ss.str(""); + ss << 1000; + EXPECT_THAT(ss.str(), StringEqWithDiff("1000")); +} + +TEST(Debug, LimitedForeach) +{ + { + std::vector values; + size_t delim_count = 0; + ckt::detail::limited_foreach( + 10, + 2, + [&](auto i) { values.push_back(i); }, + [&](auto skip_count) { + ++delim_count; + EXPECT_THAT(skip_count, Eq(10 - 2)); + }); + EXPECT_THAT(values, ElementsAreArray({0, 9})); + EXPECT_THAT(delim_count, Eq(1)); + } + + { + std::vector values; + size_t delim_count = 0; + ckt::detail::limited_foreach( + 100, + 9, + [&](auto i) { values.push_back(i); }, + [&](auto skip_count) { + ++delim_count; + EXPECT_THAT(skip_count, Eq(100 - 9)); + }); + EXPECT_THAT(values, ElementsAreArray({0, 1, 2, 3, 4, 96, 97, 98, 99})); + EXPECT_THAT(delim_count, Eq(1)); + } + + { + size_t call_count = 0; + size_t delim_count = 0; + ckt::detail::limited_foreach( + 50, + 100, + [&](auto i) { + EXPECT_THAT(i, Eq(call_count)); + ++call_count; + }, + [&]([[maybe_unused]] auto skip_count) { ++delim_count; }); + EXPECT_THAT(call_count, Eq(50)); + EXPECT_THAT(delim_count, Eq(0)); + } +} + +TEST(Debug, PrintTensor0D) +{ + auto desc = ckt::make_descriptor(ckt::Extent{}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), []([[maybe_unused]] size_t i) { return 123; }); + + std::stringstream ss; + ckt::print_tensor("0D", desc, a.get(), {}, ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"0D\": shape = []\n" + " 123\n")); +} + +TEST(Debug, PrintTensor1D) +{ + auto desc = ckt::make_descriptor(ckt::Extent{44}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i % 7; }); + + std::stringstream ss; + ckt::print_tensor("1D", desc, a.get(), {}, ss); + + // Note: output does not involve the size of the matrix separator fields, + // since these are not printed. + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"1D\": shape = [44]\n" + " 0 1 2 3 4 ... 4 5 6 0 1\n")); +} + +TEST(Debug, PrintTensor4D) +{ + auto desc = ckt::make_descriptor(ckt::Extent{100, 110, 120, 130}, + ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i; }); + + std::stringstream ss; + ckt::print_tensor("4D", + desc, + a.get(), + { + // Reduce default limits to have smaller output here. + // That also tests that we can configure these (to some + // extent). + .col_limit = 4, + .row_limit = 4, + .slice_limit = 4, + }, + ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"4D\": shape = [100, 110, 120, 130]\n" + "Tensor \"4D\", slice [0, 0, :, :]\n" + " 0 1 ... 128 129\n" + " 130 131 ... 258 259\n" + " ... ... ... ... ...\n" + " 15340 15341 ... 15468 15469\n" + " 15470 15471 ... 15598 15599\n" + "\n" + "Tensor \"4D\", slice [0, 1, :, :]\n" + " 15600 15601 ... 15728 15729\n" + " 15730 15731 ... 15858 15859\n" + " ... ... ... ... ...\n" + " 30940 30941 ... 31068 31069\n" + " 31070 31071 ... 31198 31199\n" + "\n" + "(skipping 10'996 slices...)\n" + "\n" + "Tensor \"4D\", slice [99, 108, :, :]\n" + " 171568800 171568801 ... 171568928 171568929\n" + " 171568930 171568931 ... 171569058 171569059\n" + " ... ... ... ... ...\n" + " 171584140 171584141 ... 171584268 171584269\n" + " 171584270 171584271 ... 171584398 171584399\n" + "\n" + "Tensor \"4D\", slice [99, 109, :, :]\n" + " 171584400 171584401 ... 171584528 171584529\n" + " 171584530 171584531 ... 171584658 171584659\n" + " ... ... ... ... ...\n" + " 171599740 171599741 ... 171599868 171599869\n" + " 171599870 171599871 ... 171599998 171599999\n")); +} + +TEST(Debug, PrintTensorCustomConfig) +{ + auto desc = + ckt::make_descriptor(ckt::Extent{10, 10, 10}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i * 101 % 77; }); + + std::stringstream ss; + ckt::print_tensor("CustomConfig", + desc, + a.get(), + { + // Reduce default limits to have smaller output here. + // That also tests that we can configure these. + .col_limit = 4, + .row_limit = 2, + .slice_limit = 6, + // Try with different sizes to make sure that the alignment + // is still correct after changing these. + .row_prefix = ">>>>", + .row_field_sep = "|||||", + .row_skip_val = "-------", + .matrix_row_skip_val = "&&&&&&&&", + }, + ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"CustomConfig\": shape = [10, 10, 10]\n" + "Tensor \"CustomConfig\", slice [0, :, :]\n" + ">>>>||||| 0||||| 24|||||-------||||| 38||||| 62\n" + ">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n" + ">>>>||||| 4||||| 28|||||-------||||| 42||||| 66\n" + "\n" + "Tensor \"CustomConfig\", slice [1, :, :]\n" + ">>>>||||| 13||||| 37|||||-------||||| 51||||| 75\n" + ">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n" + ">>>>||||| 17||||| 41|||||-------||||| 55||||| 2\n" + "\n" + "Tensor \"CustomConfig\", slice [2, :, :]\n" + ">>>>||||| 26||||| 50|||||-------||||| 64||||| 11\n" + ">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n" + ">>>>||||| 30||||| 54|||||-------||||| 68||||| 15\n" + "\n" + "(skipping 4 slices...)\n" + "\n" + "Tensor \"CustomConfig\", slice [7, :, :]\n" + ">>>>||||| 14||||| 38|||||-------||||| 52||||| 76\n" + ">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n" + ">>>>||||| 18||||| 42|||||-------||||| 56||||| 3\n" + "\n" + "Tensor \"CustomConfig\", slice [8, :, :]\n" + ">>>>||||| 27||||| 51|||||-------||||| 65||||| 12\n" + ">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n" + ">>>>||||| 31||||| 55|||||-------||||| 69||||| 16\n" + "\n" + "Tensor \"CustomConfig\", slice [9, :, :]\n" + ">>>>||||| 40||||| 64|||||-------||||| 1||||| 25\n" + ">>>>|||||&&&&&&&&|||||&&&&&&&&|||||-------|||||&&&&&&&&|||||&&&&&&&&\n" + ">>>>||||| 44||||| 68|||||-------||||| 5||||| 29\n")); +} + +TEST(Debug, PrintTensorUnlimitedMatrix) +{ + // To limit the output of the test, split the "unlimited" test up into one for the + // matrices and one for the slices. + + const ckt::Extent shape = ckt::Extent{12, 12}; + const ckt::TensorPrintConfig default_config; + + // The shape should be larger than the default, otherwise this test doesn't make + // any sense. + ASSERT_THAT(shape[1], Gt(default_config.col_limit)); + ASSERT_THAT(shape[2], Gt(default_config.row_limit)); + + auto desc = ckt::make_descriptor(shape, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i ^ 0xF; }); + + std::stringstream ss; + ckt::print_tensor("UnlimitedConfig", desc, a.get(), ckt::TensorPrintConfig::unlimited(), ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"UnlimitedConfig\": shape = [12, 12]\n" + " 15 14 13 12 11 10 9 8 7 6 5 4\n" + " 3 2 1 0 31 30 29 28 27 26 25 24\n" + " 23 22 21 20 19 18 17 16 47 46 45 44\n" + " 43 42 41 40 39 38 37 36 35 34 33 32\n" + " 63 62 61 60 59 58 57 56 55 54 53 52\n" + " 51 50 49 48 79 78 77 76 75 74 73 72\n" + " 71 70 69 68 67 66 65 64 95 94 93 92\n" + " 91 90 89 88 87 86 85 84 83 82 81 80\n" + " 111 110 109 108 107 106 105 104 103 102 101 100\n" + " 99 98 97 96 127 126 125 124 123 122 121 120\n" + " 119 118 117 116 115 114 113 112 143 142 141 140\n" + " 139 138 137 136 135 134 133 132 131 130 129 128\n")); +} + +TEST(Debug, PrintTensorUnlimitedSlices) +{ + // To limit the output of the test, split the "unlimited" test up into one for the + // matrices and one for the slices. + + const ckt::Extent shape = ckt::Extent{13, 1, 1}; + const ckt::TensorPrintConfig default_config; + + // The shape should be larger than the default, otherwise this test doesn't make + // any sense. + ASSERT_THAT(shape[0], Gt(default_config.slice_limit)); + + auto desc = ckt::make_descriptor(shape, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return i * 3; }); + + std::stringstream ss; + ckt::print_tensor("UnlimitedConfig", desc, a.get(), ckt::TensorPrintConfig::unlimited(), ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"UnlimitedConfig\": shape = [13, 1, 1]\n" + "Tensor \"UnlimitedConfig\", slice [0, :, :]\n" + " 0\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [1, :, :]\n" + " 3\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [2, :, :]\n" + " 6\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [3, :, :]\n" + " 9\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [4, :, :]\n" + " 12\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [5, :, :]\n" + " 15\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [6, :, :]\n" + " 18\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [7, :, :]\n" + " 21\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [8, :, :]\n" + " 24\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [9, :, :]\n" + " 27\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [10, :, :]\n" + " 30\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [11, :, :]\n" + " 33\n" + "\n" + "Tensor \"UnlimitedConfig\", slice [12, :, :]\n" + " 36\n")); +} + +TEST(Debug, PrintTensorFP32) +{ + auto desc = + ckt::make_descriptor(ckt::Extent{5, 5}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return std::pow(1.9999, i); }); + + std::stringstream ss; + ckt::print_tensor("FP32", desc, a.get(), {}, ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"FP32\": shape = [5, 5]\n" + " 1.000 2.000 4.000 7.999 15.997\n" + " 31.992 63.981 127.955 255.898 511.770\n" + " 1023.488 2046.874 4093.543 8186.677 16372.535\n" + " 32743.432 65483.590 130960.633 261908.172 523790.156\n" + " 1047527.938 2094951.125 4189692.750 8378966.500 16757095.000\n")); +} + +TEST(Debug, PrintTensorBF16) +{ + auto desc = + ckt::make_descriptor(ckt::Extent{5, 5}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer( + desc, a.get(), [](size_t i) { return ck::type_convert(1.2345678f * i); }); + + std::stringstream ss; + ckt::print_tensor("BF16", desc, a.get(), {}, ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"BF16\": shape = [5, 5]\n" + " 0.000 1.234 2.469 3.703 4.938\n" + " 6.188 7.406 8.625 9.875 11.125\n" + " 12.375 13.562 14.812 16.000 17.250\n" + " 18.500 19.750 21.000 22.250 23.500\n" + " 24.750 25.875 27.125 28.375 29.625\n")); +} + +TEST(Debug, PrintTensorFP8) +{ + auto desc = + ckt::make_descriptor(ckt::Extent{5, 5}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer( + desc, a.get(), [](size_t i) { return ck::type_convert(i * 0.1f); }); + + std::stringstream ss; + ckt::print_tensor("FP8", desc, a.get(), {}, ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"FP8\": shape = [5, 5]\n" + " 0.000 0.102 0.203 0.312 0.406\n" + " 0.500 0.625 0.688 0.812 0.875\n" + " 1.000 1.125 1.250 1.250 1.375\n" + " 1.500 1.625 1.750 1.750 1.875\n" + " 2.000 2.000 2.250 2.250 2.500\n")); +} + +TEST(Debug, PrintTensorSpecialFloats) +{ + auto desc = + ckt::make_descriptor(ckt::Extent{5, 5}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { + if(i % 8 == 1) + return 0.f / 0.f; + else if(i % 7 == 1) + return std::sqrt(-1.f); + else if(i % 6 == 1) + return 1.f / 0.f; + else if(i % 5 == 1) + return -1.f / 0.f; + else + return static_cast(i); + }); + + std::stringstream ss; + ckt::print_tensor("specials", desc, a.get(), {}, ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"specials\": shape = [5, 5]\n" + " 0.000 nan 2.000 3.000 4.000\n" + " 5.000 -inf inf -nan nan\n" + " 10.000 -inf 12.000 inf 14.000\n" + " -nan -inf nan 18.000 inf\n" + " 20.000 -inf -nan 23.000 24.000\n")); +} + +TEST(Debug, PrintTensorFloatPrecision) +{ + auto desc = ckt::make_descriptor(ckt::Extent{5}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return std::pow(0.9, i); }); + + std::stringstream ss; + ckt::print_tensor("FloatPrecision", + desc, + a.get(), + { + .float_precision = 10, + }, + ss); + + EXPECT_THAT(ss.str(), + StringEqWithDiff( // + "Tensor \"FloatPrecision\": shape = [5]\n" + " 1.0000000000 0.8999999762 0.8100000024 0.7289999723 0.6560999751\n")); +} diff --git a/experimental/builder/test/unit_device_buffer.cpp b/experimental/builder/test/unit_device_buffer.cpp index c7180395b7..548b055238 100644 --- a/experimental/builder/test/unit_device_buffer.cpp +++ b/experimental/builder/test/unit_device_buffer.cpp @@ -88,3 +88,11 @@ TEST(DeviceBuffer, AllocTensorBuffer) EXPECT_THAT(hipMemset(buffer.get(), 0xFF, descriptor.get_element_space_size_in_bytes()), HipSuccess()); } + +TEST(DeviceBuffer, AlignForward) +{ + EXPECT_THAT(ckt::align_fwd(24, 8), Eq(24)); + EXPECT_THAT(ckt::align_fwd(25, 8), Eq(32)); + EXPECT_THAT(ckt::align_fwd(0xd7c563, 0x1000), Eq(0xd7d000)); + EXPECT_THAT(ckt::align_fwd(19561, 23), Eq(19573)); +} diff --git a/experimental/builder/test/unit_tensor_descriptor.cpp b/experimental/builder/test/unit_tensor_descriptor.cpp index 672ebbd88a..ce6209795a 100644 --- a/experimental/builder/test/unit_tensor_descriptor.cpp +++ b/experimental/builder/test/unit_tensor_descriptor.cpp @@ -6,11 +6,13 @@ #include #include #include +#include #include namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; +using ck_tile::test::StringEqWithDiff; using ::testing::ElementsAreArray; using ::testing::Eq; using ::testing::Throws; @@ -76,7 +78,7 @@ TEST(TensorDescriptor, MakeDescriptor) // Note: automatic inference of RANK. const auto desc = - ckt::make_descriptor(lengths, ckt::PackedRightLayout{}); + ckt::make_descriptor(lengths, ckt::PackedRightLayout{}); EXPECT_THAT(desc.get_lengths(), ElementsAreArray(lengths)); EXPECT_THAT(desc.get_strides(), @@ -173,7 +175,7 @@ TEST(TensorDescriptor, ExtentFromVector) TEST(TensorDescriptor, IsPacked) { - constexpr auto dt = ckb::DataType::INT32; // Irrelevant for this test + constexpr auto dt = ckb::DataType::I32; // Irrelevant for this test EXPECT_TRUE( ckt::make_descriptor
(ckt::Extent{101, 43, 25, 662, 654}, ckt::PackedLeftLayout{}) .is_packed()); @@ -189,3 +191,20 @@ TEST(TensorDescriptor, IsPacked) EXPECT_FALSE( ckt::make_descriptor
(ckt::Extent{30, 20, 10}, ckt::Extent{1, 1, 1}).is_packed()); } + +TEST(TensorDescriptor, PrintExtent) +{ + { + const ckt::Extent extent{6233, 55, 1235, 52, 203}; + std::stringstream ss; + ss << extent; + EXPECT_THAT(ss.str(), StringEqWithDiff("[6233, 55, 1235, 52, 203]")); + } + + { + const ckt::Extent extent{}; + std::stringstream ss; + ss << extent; + EXPECT_THAT(ss.str(), StringEqWithDiff("[]")); + } +} diff --git a/experimental/builder/test/unit_tensor_foreach.cpp b/experimental/builder/test/unit_tensor_foreach.cpp index de635bc09b..f689d3c82f 100644 --- a/experimental/builder/test/unit_tensor_foreach.cpp +++ b/experimental/builder/test/unit_tensor_foreach.cpp @@ -16,6 +16,28 @@ namespace ckt = ck_tile::builder::test; using ::testing::Each; using ::testing::Eq; +TEST(TensorForeach, NdIter) +{ + { + ckt::NdIter iter(ckt::Extent{523, 345, 123, 601}); + + EXPECT_THAT(iter.numel(), Eq(13'338'296'505ULL)); + EXPECT_THAT(iter(0), Eq(ckt::Extent{0, 0, 0, 0})); + EXPECT_THAT(iter(1), Eq(ckt::Extent{0, 0, 0, 1})); + EXPECT_THAT(iter(601), Eq(ckt::Extent{0, 0, 1, 0})); + EXPECT_THAT(iter(601 * 123), Eq(ckt::Extent{0, 1, 0, 0})); + EXPECT_THAT(iter(601 * 123 * 10), Eq(ckt::Extent{0, 10, 0, 0})); + EXPECT_THAT(iter(((34 * 345 + 63) * 123 + 70) * 601 + 5), Eq(ckt::Extent{34, 63, 70, 5})); + } + + { + ckt::NdIter iter(ckt::Extent{}); + + EXPECT_THAT(iter.numel(), Eq(1)); + EXPECT_THAT(iter(0), Eq(ckt::Extent{})); + } +} + TEST(TensorForeach, CalculateOffset) { EXPECT_THAT(ckt::calculate_offset(ckt::Extent{1, 2, 3}, ckt::Extent{100, 10, 1}), Eq(123)); @@ -87,8 +109,8 @@ TEST(TensorForeach, VisitsEveryIndex) TEST(TensorForeach, FillTensorBuffer) { - auto desc = ckt::make_descriptor(ckt::Extent{31, 54, 13}, - ckt::PackedRightLayout{}); + auto desc = + ckt::make_descriptor(ckt::Extent{31, 54, 13}, ckt::PackedRightLayout{}); auto buffer = ckt::alloc_tensor_buffer(desc); @@ -109,7 +131,7 @@ TEST(TensorForeach, FillTensor) // FillTensor with non-packed indices should not write out-of-bounds. const ckt::Extent shape = {4, 23, 35}; const ckt::Extent pad = {12, 53, 100}; - auto desc = ckt::make_descriptor(shape, ckt::PackedRightLayout{}(pad)); + auto desc = ckt::make_descriptor(shape, ckt::PackedRightLayout{}(pad)); const auto strides = desc.get_strides(); auto size = desc.get_element_space_size(); @@ -169,7 +191,7 @@ TEST(TensorForeach, ClearTensorZeros) const ckt::Extent pad = {6, 6, 6, 6, 6, 6, 6, 6}; const auto desc = - ckt::make_descriptor(shape, ckt::PackedRightLayout{}(pad)); + ckt::make_descriptor(shape, ckt::PackedRightLayout{}(pad)); auto buffer = ckt::alloc_tensor_buffer(desc); ckt::clear_tensor_buffer(desc, buffer.get()); diff --git a/experimental/builder/test/unit_validation.cpp b/experimental/builder/test/unit_validation.cpp index 5f6b620d6b..a83d034ac2 100644 --- a/experimental/builder/test/unit_validation.cpp +++ b/experimental/builder/test/unit_validation.cpp @@ -173,8 +173,8 @@ TEST(ValidationReportTests, MultipleSomeIncorrect) } { - auto desc = ckt::make_descriptor({'G', 'P', 'U'}, - ckt::PackedRightLayout{}); + auto desc = + ckt::make_descriptor({'G', 'P', 'U'}, ckt::PackedRightLayout{}); auto a = ckt::alloc_tensor_buffer(desc); auto b = ckt::alloc_tensor_buffer(desc); @@ -204,6 +204,7 @@ struct DummySignature constexpr DummySignature DUMMY_SIGNATURE = {}; namespace ck_tile::builder::test { + template <> struct Args { @@ -225,6 +226,7 @@ struct Outputs void* b; }; +// Explicitly implement validate for this type to test that that works. template <> ValidationReport validate(const Args& args, Outputs actual, diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index ad5a5f4f6f..3b83ead2d0 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -15,31 +15,42 @@ using namespace test; constexpr DlThreadConfig DlThreadConfig_16x2x4x4x1{ .k0_per_block = 16, .k1 = 2, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; +constexpr DlThreadConfig DlThreadConfig_16x1x4x4x1{ + .k0_per_block = 16, .k1 = 1, .m1_per_thread = 4, .n1_per_thread = 4, .k_per_thread = 1}; + constexpr DlThreadCluster DlThreadCluster_8x2{.m1_xs = {8, 2}, .n1_xs = {8, 2}}; -constexpr DlBlockTransfer DlBlockTransferAB{.thread_slice_lengths = {8, 1, 1, 2}, - .thread_cluster_lengths = {2, 1, 128, 1}, - .thread_cluster_arrange_order = {1, 2, 0, 3}, - .src_access_order = {1, 2, 0, 3}, - .src_vector_tensor_lengths = {4, 1, 1, 2}, - .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, - .dst_vector_tensor_lengths = {1, 1, 1, 2}}; +constexpr DlBlockTransfer<4> DlBlockTransfer_8x1x1x2{ + .thread_slice_lengths = {8, 1, 1, 2}, + .thread_cluster_lengths = {2, 1, 128, 1}, + .thread_cluster_arrange_order = {1, 2, 0, 3}, + .src_access_order = {1, 2, 0, 3}, + .src_vector_tensor_lengths = {4, 1, 1, 2}, + .src_vector_tensor_contiguous_dim_order = {1, 2, 0, 3}, + .dst_vector_tensor_lengths = {1, 1, 1, 2}}; -constexpr DlTransferABC DlFwdTransfer{.a = - { - .block_transfer = DlBlockTransferAB, - }, - .b = - { - .block_transfer = DlBlockTransferAB, - }, - .c = { - .epilogue = {.src_dst_access_order = {0, 1, 2, 3, 4, 5}, - .src_dst_vector_dim = 5, - .dst_scalar_per_vector = 4}, - }}; +constexpr DlTransfer<4> DlTransfer4D{.a = DlBlockTransfer_8x1x1x2, + .b = DlBlockTransfer_8x1x1x2, + .c = {.src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 4}}; -constexpr TransferABC FwdTransfer_4x64x1{ +constexpr DlBlockTransfer<5> DlBlockTransfer_1x8x1x1x1{ + .thread_slice_lengths = {1, 8, 1, 1, 1}, + .thread_cluster_lengths = {1, 2, 1, 128, 1}, + .thread_cluster_arrange_order = {0, 2, 3, 1, 4}, + .src_access_order = {0, 2, 3, 1, 4}, + .src_vector_tensor_lengths = {1, 1, 1, 1, 1}, + .src_vector_tensor_contiguous_dim_order = {0, 2, 3, 1, 4}, + .dst_vector_tensor_lengths = {1, 1, 1, 1, 1}}; + +constexpr DlTransfer<5> DlTransfer5D{.a = DlBlockTransfer_1x8x1x1x1, + .b = DlBlockTransfer_1x8x1x1x1, + .c = {.src_dst_access_order = {0, 1, 2, 3, 4, 5}, + .src_dst_vector_dim = 5, + .dst_scalar_per_vector = 1}}; + +constexpr Transfer<> Transfer_4x64x1{ .a = { .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, @@ -72,7 +83,73 @@ constexpr TransferABC FwdTransfer_4x64x1{ }, }; -constexpr TransferABC FwdTransfer_4x64x1_fp8{ +constexpr Transfer<4> BwdTransfer_4x64x1{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = true}, + .block_transfer_access_order = {0, 3, 1, 2}, + .src_access_order = {0, 2, 1, 3}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1, .k_batch_size = 1}, + .lds_transfer = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 4, + .is_direct_load = false, + .lds_padding = true}, + .block_transfer_access_order = {0, 3, 1, 2}, + .src_access_order = {0, 2, 1, 3}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + }, +}; + +constexpr Transfer<> BwdTransfer_4x8x1_4x16x1_v3{ + .a = + { + .block_transfer = {.k0 = 4, .m_n = 8, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 1, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .block_transfer_access_order = {2, 0, 1}, + .src_access_order = {1, 0, 2}, + }, + .b = + { + .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, + .lds_transfer = {.src_vector_dim = 1, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, + .lds_padding = false}, + .block_transfer_access_order = {2, 0, 1}, + .src_access_order = {1, 0, 2}, + }, + .c = + { + .thread_cluster_dims = + {.m_block = 1, .m_wave_per_xdl = 8, .n_block = 1, .n_wave_per_xdl = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 2}, + }, +}; + +constexpr Transfer<> Transfer_4x64x1_fp8{ .a = { .block_transfer = {.k0 = 4, .m_n = 64, .k1 = 1}, @@ -105,7 +182,7 @@ constexpr TransferABC FwdTransfer_4x64x1_fp8{ }, }; -constexpr TransferABC FwdTransfer_4x16x1{ +constexpr Transfer<> Transfer_4x16x1{ .a = { .block_transfer = {.k0 = 4, .m_n = 16, .k1 = 1}, @@ -139,7 +216,7 @@ constexpr TransferABC FwdTransfer_4x16x1{ }, }; -constexpr TransferABC FwdTransfer_4x32x1{ +constexpr Transfer<> Transfer_4x32x1{ .a = { .block_transfer = {.k0 = 4, .m_n = 32, .k1 = 1}, @@ -172,59 +249,80 @@ constexpr TransferABC FwdTransfer_4x32x1{ }, }; -constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ - .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}; +constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_4x4_per_wave{ + .k1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; -constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x2_per_wave{ - .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}; +constexpr GridwiseBwdXdlGemm BwdGemmParams_Xdl_1x1_per_wave{ + .k1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 1, .n_xdl_per_wave = 1}}; -constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x2_per_wave{ - .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}; +constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}}; -constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ - .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}; +constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_4x2_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 2}}; -constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8, - .m_per_wmma = 32, - .n_per_wmma = 32, - .m_wmma_per_wave = 2, - .n_wmma_per_wave = 1, - .pipeline_version = PipelineVersion::V1}; +constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x2_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}}; -constexpr ThreadBlock FwdThreadBlock_256_256x256x32{.block_size = 256, - .tile_size = {.m = 256, .n = 256, .k = 32}}; +constexpr GridwiseFwdXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ + .ak1 = 8, + .bk1 = 8, + .xdl_params = {.m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}}; -constexpr ThreadBlock FwdThreadBlock_256_256x128x32{.block_size = 256, - .tile_size = {.m = 256, .n = 128, .k = 32}}; +constexpr GridwiseWmmaGemm GemmParams_Wmma_2x1_per_wave{ + .k1 = 8, .m_per_wmma = 32, .n_per_wmma = 32, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1}; -constexpr ThreadBlock FwdThreadBlock_256_128x128x32{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 32}}; +constexpr GridwiseWmmaGemm GemmParams_Wmma_16x16_2x1_per_wave{ + .k1 = 8, .m_per_wmma = 16, .n_per_wmma = 16, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1}; -constexpr ThreadBlock FwdThreadBlock_256_128x128x16{.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 16}}; +constexpr ThreadBlock ThreadBlock_256_256x256x32{.block_size = 256, + .tile_size = {.m = 256, .n = 256, .k = 32}}; -constexpr ThreadBlock FwdThreadBlock_64_64x32x32{.block_size = 64, - .tile_size = {.m = 64, .n = 32, .k = 32}}; +constexpr ThreadBlock ThreadBlock_256_256x128x32{.block_size = 256, + .tile_size = {.m = 256, .n = 128, .k = 32}}; -constexpr ThreadBlock FwdThreadBlock_128_128x128x32{.block_size = 128, - .tile_size = {.m = 128, .n = 128, .k = 32}}; +constexpr ThreadBlock ThreadBlock_256_128x128x32{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 32}}; -constexpr ThreadBlock FwdThreadBlock_128_64x64x64{.block_size = 128, - .tile_size = {.m = 64, .n = 64, .k = 64}}; +constexpr ThreadBlock ThreadBlock_256_128x128x16{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 16}}; -constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = PipelineVersion::V1, - .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr ThreadBlock ThreadBlock_256_128x128x8{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 8}}; -constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = PipelineVersion::V2, - .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr ThreadBlock ThreadBlock_64_64x32x32{.block_size = 64, + .tile_size = {.m = 64, .n = 32, .k = 32}}; -constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = PipelineVersion::V3, - .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr ThreadBlock ThreadBlock_64_32x32x32{.block_size = 64, + .tile_size = {.m = 32, .n = 32, .k = 32}}; -constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = PipelineVersion::V4, - .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr ThreadBlock ThreadBlock_128_128x128x32{.block_size = 128, + .tile_size = {.m = 128, .n = 128, .k = 32}}; -constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = PipelineVersion::V5, - .scheduler = PipelineScheduler::INTRAWAVE}; +constexpr ThreadBlock ThreadBlock_128_64x64x64{.block_size = 128, + .tile_size = {.m = 64, .n = 64, .k = 64}}; + +constexpr BlockGemmPipeline BlockGemmDesc_v1_intrawave = { + .pipeline_version = PipelineVersion::V1, .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemmPipeline BlockGemmDesc_v2_intrawave = { + .pipeline_version = PipelineVersion::V2, .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemmPipeline BlockGemmDesc_v3_intrawave = { + .pipeline_version = PipelineVersion::V3, .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemmPipeline BlockGemmDesc_v4_intrawave = { + .pipeline_version = PipelineVersion::V4, .scheduler = PipelineScheduler::INTRAWAVE}; + +constexpr BlockGemmPipeline BlockGemmDesc_v5_intrawave = { + .pipeline_version = PipelineVersion::V5, .scheduler = PipelineScheduler::INTRAWAVE}; } // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp index 377234dd19..41a1250854 100644 --- a/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_tile_test_configs.hpp @@ -12,35 +12,35 @@ namespace ck_tile::builder::test_utils { using namespace ck_tile::builder; using namespace test; -constexpr TileTransfer FwdTileTransfer_1x1x1{ +constexpr TileTransfer TileTransfer_1x1x1{ .a_scalar_per_vector = 1, .b_scalar_per_vector = 1, .c_scalar_per_vector = 1, }; -constexpr TileTransfer FwdTileTransfer_4x4x4{ +constexpr TileTransfer TileTransfer_4x4x4{ .a_scalar_per_vector = 4, .b_scalar_per_vector = 4, .c_scalar_per_vector = 4, }; -constexpr TileTransfer FwdTileTransfer_8x8x8{ +constexpr TileTransfer TileTransfer_8x8x8{ .a_scalar_per_vector = 8, .b_scalar_per_vector = 8, .c_scalar_per_vector = 8, }; -constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}}; +constexpr TileThreadBlock TileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}}; -constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}}; +constexpr TileThreadBlock TileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}}; -constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}}; +constexpr TileThreadBlock TileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}}; -constexpr TileThreadBlock FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}}; +constexpr TileThreadBlock TileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}}; -constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}}; +constexpr TileThreadBlock TileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}}; -constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}}; +constexpr TileThreadBlock TileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}}; constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = { .warps = {.m = 2, .n = 2, .k = 1}, diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index e4db149a98..23f4cf3364 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -54,7 +54,7 @@ inline std::string to_string(PipelineScheduler t) } template <> -inline std::string to_string(ConvFwdSpecialization t) +inline std::string to_string(ConvSpecialization t) { std::ostringstream oss; oss << t; @@ -86,11 +86,20 @@ inline std::string to_string(ThreadBlock t) } template <> -inline std::string to_string(GridwiseXdlGemm t) +inline std::string to_string(GridwiseBwdXdlGemm t) { std::ostringstream oss; - oss << t.ak1 << "," << t.bk1 << "," << t.m_per_xdl << "," << t.n_per_xdl << "," - << t.m_xdl_per_wave << "," << t.n_xdl_per_wave; + oss << t.k1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl << "," + << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; + return oss.str(); +} + +template <> +inline std::string to_string(GridwiseFwdXdlGemm t) +{ + std::ostringstream oss; + oss << t.ak1 << "," << t.bk1 << "," << t.xdl_params.m_per_xdl << "," << t.xdl_params.n_per_xdl + << "," << t.xdl_params.m_xdl_per_wave << "," << t.xdl_params.n_xdl_per_wave; return oss.str(); } @@ -104,17 +113,29 @@ inline std::string to_string(GridwiseWmmaGemm t) } template <> -inline std::string to_string(BlockGemm t) +inline std::string to_string(BlockGemmPipeline t) { std::ostringstream oss; oss << to_string(t.scheduler) << "," << to_string(t.pipeline_version); return oss.str(); } -template <> -inline std::string to_string(BlockTransfer t) +template +inline std::string to_string(BlockTransfer t) { - return array_to_seq(std::array{t.k0, t.m_n, t.k1}); + if constexpr(ThreadClusterRank == 4) + { + return array_to_seq(std::array{t.k_batch_size, t.k0, t.m_n, t.k1}); + } + else if constexpr(ThreadClusterRank == 3) + { + return array_to_seq(std::array{t.k0, t.m_n, t.k1}); + } + else + { + static_assert(ThreadClusterRank == 3 || ThreadClusterRank == 4, + "Unsupported ThreadClusterRank"); + } } template <> @@ -134,14 +155,14 @@ inline std::string to_string(LdsTransfer t) return oss.str(); } -template <> -inline std::string to_string(AccessOrder t) +template +inline std::string to_string(AccessOrder t) { return array_to_seq(t.order); } -template <> -inline std::string to_string(TransferAB t) +template +inline std::string to_string(InputTransfer t) { std::ostringstream oss; oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << "," @@ -152,7 +173,7 @@ inline std::string to_string(TransferAB t) } template <> -inline std::string to_string(TransferC t) +inline std::string to_string(OutputTransfer t) { std::ostringstream oss; oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << "," @@ -160,8 +181,8 @@ inline std::string to_string(TransferC t) return oss.str(); } -template <> -inline std::string to_string(TransferABC t) +template +inline std::string to_string(Transfer t) { std::ostringstream oss; oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); @@ -185,7 +206,19 @@ inline std::string to_string(DlThreadCluster t) } template <> -inline std::string to_string(DlBlockTransfer t) +inline std::string to_string>(DlBlockTransfer<4> t) +{ + std::ostringstream oss; + oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths) + << "," << array_to_seq(t.thread_cluster_arrange_order) << "," + << array_to_seq(t.src_access_order) << "," << array_to_seq(t.src_vector_tensor_lengths) + << "," << array_to_seq(t.src_vector_tensor_contiguous_dim_order) << "," + << array_to_seq(t.dst_vector_tensor_lengths); + return oss.str(); +} + +template <> +inline std::string to_string>(DlBlockTransfer<5> t) { std::ostringstream oss; oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths) @@ -206,19 +239,24 @@ inline std::string to_string(DlEpilogue t) } template <> -inline std::string to_string(DlBlockTransferAB t) +inline std::string to_string(TransposeParams_ t) { - return to_string(t.block_transfer); + std::ostringstream oss; + oss << t.max_transpose_transfer_src_scalar_per_vector << "," + << t.max_transpose_transfer_dst_scalar_per_vector; + return oss.str(); } template <> -inline std::string to_string(DlBlockTransferC t) +inline std::string to_string>(DlTransfer<4> t) { - return to_string(t.epilogue); + std::ostringstream oss; + oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); + return oss.str(); } template <> -inline std::string to_string(DlTransferABC t) +inline std::string to_string>(DlTransfer<5> t) { std::ostringstream oss; oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c); @@ -234,7 +272,13 @@ inline std::string to_string(ThreadBlock_ t) } template <> -inline std::string to_string(XdlGemm_ t) +inline std::string to_string(FwdXdlGemm_ t) +{ + return to_string(t.gridwise_gemm); +} + +template <> +inline std::string to_string(BwdXdlGemm_ t) { return to_string(t.gridwise_gemm); } @@ -245,33 +289,40 @@ inline std::string to_string(WmmaGemm_ t) return to_string(t.gridwise_gemm); } -template <> -inline std::string to_string(Transfer_ t) +template +inline std::string to_string(Transfer_ t) { return to_string(t.transfer); } template <> -inline std::string to_string(ConvSpecialization_ t) +inline std::string to_string(ConvSpecializationFwd_ t) { std::ostringstream oss; oss << to_string(t.fwd_specialization) << "," << to_string(t.gemm_specialization); return oss.str(); } +template <> +inline std::string to_string(ConvSpecializationBwdWeight_ t) +{ + std::ostringstream oss; + oss << to_string(t.bwd_weight_specialization); + return oss.str(); +} + template <> inline std::string to_string(Prefetch_ t) { std::ostringstream oss; - oss << t.num_gemm_k_prefetch_stages << "," << t.num_groups_to_merge << "," - << to_string(t.loop_scheduler); + oss << t.num_gemm_k_prefetch_stages << "," << to_string(t.loop_scheduler); return oss.str(); } template <> inline std::string to_string(BlockGemm_ t) { - return to_string(t.block_gemm); + return to_string(t.block_gemm_pipeline); } template <> @@ -287,7 +338,13 @@ inline std::string to_string(DlThreadCluster_ t) } template <> -inline std::string to_string(DlTransfer_ t) +inline std::string to_string>(DlTransfer_<4> t) +{ + return to_string(t.transfer); +} + +template <> +inline std::string to_string>(DlTransfer_<5> t) { return to_string(t.transfer); } @@ -299,8 +356,8 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast(t)); + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -309,8 +366,8 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast(t)); + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -320,7 +377,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast(t)); + << "," << to_string(static_cast>(t)); return oss.str(); } @@ -332,7 +389,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) << "," << to_string(static_cast(t)) << "," - << to_string(static_cast(t)); + << to_string(static_cast>(t)); return oss.str(); } @@ -340,7 +397,102 @@ template <> inline std::string to_string( ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor t) { - return to_string(t.base_algorithm); + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Xdl_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast(t)) << "," + << to_string(static_cast>(t)); + return oss.str(); +} + +template <> +inline std::string to_string( + ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle t) +{ + std::ostringstream oss; + oss << to_string(static_cast(t)) << "," << to_string(static_cast(t)) + << "," << to_string(static_cast>(t)); + return oss.str(); } } // namespace ck_tile::builder::test diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index 3b12e7feb0..4f884b1df3 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -151,7 +151,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1 PrefetchStages; } + static bool __host__ __device__ BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } static TailNumber BlockLoopTailNum(index_t num_loop) { @@ -707,7 +710,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1 PrefetchStages; } + __host__ __device__ static bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } static TailNumber BlockLoopTailNum(index_t num_loop) { diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp index ade8035877..2154f35815 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp @@ -3,6 +3,11 @@ #pragma once +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/stream_utility.hpp" + #include "device_grouped_gemm.hpp" namespace ck { @@ -43,6 +48,59 @@ struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm +struct TileLoopKernelConfig +{ + // The oversubscription factor for the number of blocks that can simultaneously reside on + // GPU. + static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1; + // static constexpr int BLOCK_WAVES = BlockSize / get_warp_size(); + static constexpr int CU_SIMDS = 4; + // Assume we want to have at most 2 waves per SIMD + // static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); + static int GetCuBlocks() + { + int BLOCK_WAVES = BlockSize / get_warp_size(); + return ck::math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); + } + + template + static int CalculateMaxOccupancyGridSize(const KernelFunction& kernel, + const StreamConfig& stream_config) + { + // Calculate max number of workgroups that can simultaneously reside on the CU. + int occ_num_blocks = GetKernelOccupancy(kernel); + int cu_count = getAvailableComputeUnitCount(stream_config); + + if(stream_config.log_level_ > 0) + { + std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks + << ", available CUs count: " << cu_count << ", occup. grid size: " + << ck::math::min(occ_num_blocks, GetCuBlocks()) * cu_count << std::endl; + } + + return cu_count * ck::math::min(occ_num_blocks, GetCuBlocks()); + } + + template + static int GetKernelOccupancy(const KernelFunction& kernel) + { + int occupancy = 0; + ck::hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + return occupancy; + } + + static int GetComputeUnitCount() + { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + ck::hip_check_error(hipGetDevice(&dev)); + ck::hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + } +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index 30c1b1d490..bc072a7019 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -50,7 +50,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3( + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d( typename GridwiseGemm::Argument karg, const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, @@ -858,30 +858,32 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { if(gemm_arg.KBatch > 1) { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; Run(kernel); } else { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; Run(kernel); } } @@ -897,30 +899,32 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { if(gemm_arg.KBatch > 1) { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; Run(kernel); } else { - const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3< - GridwiseGemm, - remove_reference_t, - remove_reference_t, - remove_reference_t< - DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; Run(kernel); } } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 1807dc1d9f..d3bf2a364a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -52,19 +52,20 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const index_t batch_count, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_batched_gemm_xdlops_bwd_weight_multiple_d( + const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const index_t batch_count, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) @@ -568,7 +569,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle int max_occupancy = 0; hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( &max_occupancy, - kernel_batched_gemm_xdlops_bwd_weight< + kernel_batched_gemm_xdlops_bwd_weight_multiple_d< GridwiseGemm, ADataType, BDataType, @@ -841,7 +842,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); }; - const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< + const auto kernel = kernel_batched_gemm_xdlops_bwd_weight_multiple_d< GridwiseGemm, ADataType, BDataType, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp new file mode 100644 index 0000000000..b7c0d89e0f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp @@ -0,0 +1,689 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/host_utility/stream_utility.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// +/// @brief Entry point kernel for device-wide Grouped GEMM operation. +/// +/// @param[in] gemm_descs_const The pointer to the array of GEMM descriptor structures. +/// @param[in] group_count The number of together processed GEMMs. +/// +/// @tparam GridwiseGemm The specific GridwiseGEMM algorithm implementation. +/// @tparam GemmDesc The structure holding all necessary descriptors and +/// other data needed for grouped gemm calculation and work +/// distribution. +/// @tparam LocalBlock2ETileMap The structure providing mapping between workgroup ids, +/// the data tiles to process and the output tiles. +/// +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_multiple_d_wmma(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ uint8_t p_shared[LDS_size]; + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + constexpr auto NumDTensor = DsDataType::Size(); + index_t tile_id = get_block_1d_id(); + index_t tile_offset = 0; + index_t group_id = -1; + index_t group_offset = 0; + index_t grid_size_grp = 0; + + index_t gemm_tile_id_start = 0; + index_t gemm_tile_id_end = 0; + + index_t M = 0, N = 0, K = 0; + + auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1); + + do + { + // Find corresponding GEMM group for our tile + while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) && + group_id < group_count) + { + group_offset += grid_size_grp; + group_id++; + + if(group_id >= group_count) + return; + + M = gemm_desc_ptr[group_id].M; + N = gemm_desc_ptr[group_id].N; + K = gemm_desc_ptr[group_id].K; + + if(M == 0 || N == 0 || K == 0) + { + grid_size_grp = 0; + continue; + } + + b2c_tile_map = + OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset); + grid_size_grp = b2c_tile_map.CalculateGridSize(M, N); + + gemm_tile_id_start = group_offset; + gemm_tile_id_end = group_offset + grid_size_grp; + } + + // Create A&B grid pointer containing their single tensors + typename GridwiseGemm::AsGridPointer p_as_grid = Tuple( + static_cast(gemm_desc_ptr[group_id].p_a_grid)); + typename GridwiseGemm::BsGridPointer p_bs_grid = Tuple( + static_cast(gemm_desc_ptr[group_id].p_b_grid)); + + // Make a DsGridPointer instance containing all D tensors + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + DsGridPointer p_ds_grid; + std::array stride_Ds; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + p_ds_grid(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + stride_Ds[i] = gemm_desc_ptr[group_id].StrideDs[i]; + }); + + index_t K_split = ck::math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + // Update tile offset if we have moved within group + b2c_tile_map.UpdateTileOffset(tile_offset); + + using Problem = typename GridwiseGemm::Problem; + auto problem = Problem(gemm_desc_ptr[group_id].M, + gemm_desc_ptr[group_id].N, + gemm_desc_ptr[group_id].K, + std::array{gemm_desc_ptr[group_id].StrideA}, + std::array{gemm_desc_ptr[group_id].StrideB}, + stride_Ds, + gemm_desc_ptr[group_id].StrideE, + 1); + + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + constexpr TailNumber TailNum = TailNumber::Full; + + if(has_main_k_block_loop) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + GridwiseGemm::template Run( + p_as_grid, + p_bs_grid, + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + b2c_tile_map, + a_element_op, + b_element_op, + cde_element_op, + epilogue_args); + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + GridwiseGemm::template Run( + p_as_grid, + p_bs_grid, + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + b2c_tile_map, + a_element_op, + b_element_op, + cde_element_op, + epilogue_args); + } + } + + tile_id += get_grid_size(); + tile_offset += get_grid_size(); + + } while(group_id < group_count); +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif // end of if (defined(__gfx11__) || defined(__gfx12__)) +} + +template + +struct DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3 + : public DeviceGroupedGemmTileLoop +{ + using DeviceOp = DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // PermuteA not supported by GridwiseOp. + false>; // PermuteB not supported by DeviceGroupedGemmTileLoop base class. + + using KernelConfig = TileLoopKernelConfig; + using KernelArguments = GroupedGemmKernelArgument; + using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::vector& /* p_As */, + std::vector& /* p_Bs */, + std::vector>& /* p_Ds */, + std::vector& /* p_Es */, + const std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + int occupancy_num_blocks, + int gpu_cu_count) + : group_count_{static_cast(gemm_descs.size())}, + occupancy_num_blocks_{occupancy_num_blocks}, + gpu_cu_count_{gpu_cu_count}, + gemm_descs_{gemm_descs}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + tile_count_{0} + { + for(const auto& desc : gemm_descs) + { + const auto M = desc.M_; + const auto N = desc.N_; + const auto b2c_tile_map = Block2ETileMap(M, N); + tile_count_ += b2c_tile_map.CalculateGridSize(M, N); + } + } + + index_t group_count_; + const void* p_dev_gemm_args_; + int occupancy_num_blocks_; + int gpu_cu_count_; + const std::vector& gemm_descs_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + index_t tile_count_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using user provided device buffer for kernel + /// arguments. + /// + /// @param[in] arg The structure containing kernel arguments (in host + /// memory). + /// @param[in] dev_gemm_args The pointer to device memory with kernel arguments. + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, + const void* dev_gemm_args, + const StreamConfig& stream_config = StreamConfig{}) + { + if(dev_gemm_args == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + const auto kernel = GetKernelFunction(); + + int grid_size = KernelConfig::CalculateMaxOccupancyGridSize(kernel, stream_config); + + if(stream_config.log_level_ > 0) + { + std::cout << "grid_size: " << grid_size << " tile_count: " << arg.tile_count_ + << std::endl; + } + + // run multiple kernels + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(dev_gemm_args), + arg.group_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + } + + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using device buffers (for kernel arguments and + /// for kernel auxiliary workspace) provided with an argument. The user should + /// call @see GetDeviceKernelArgSize, and @see SetDeviceKernelArgs, on arg + /// parameter to properly allocate those buffers. + /// + /// @param[in] arg The structure containing kernel arguments (in host memory). + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(arg.p_dev_gemm_args_ == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return Run(arg, arg.p_dev_gemm_args_, stream_config); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static auto GetKernelFunction() + { + const auto kernel = kernel_grouped_gemm_multiple_d_wmma; + return kernel; + } + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + bool supported = true; + for(index_t i = 0; i < arg.group_count_; ++i) + { + std::array placeholder_p_ds_grid{}; + std::array stride_Ds; + std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin()); + + typename GridwiseGemm::Argument gridwise_arg( + std::array{nullptr}, // p_a_grid, + std::array{nullptr}, // p_b_grid, + placeholder_p_ds_grid, // p_ds_grid, + nullptr, // p_e_grid , + arg.gemm_descs_[i].M_, + arg.gemm_descs_[i].N_, + arg.gemm_descs_[i].K_, + std::array{arg.gemm_descs_[i].stride_A_}, + std::array{arg.gemm_descs_[i].stride_B_}, + stride_Ds, + arg.gemm_descs_[i].stride_C_, + 1, // KBatch + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + false); + + bool group_arg_valid = GridwiseGemm::CheckValidity(gridwise_arg); + supported = supported && group_arg_valid; + + if(!group_arg_valid) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" << std::endl; + gridwise_arg.Print(); + } + } + } + + return supported; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static int GetKernelOccupancy() + { + const auto kernel = GetKernelFunction(); + return KernelConfig::GetKernelOccupancy(kernel); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) + { + int occupancy = GetKernelOccupancy(); + int num_cu = KernelConfig::GetComputeUnitCount(); + + return Argument{p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op, + occupancy, + num_cu}; + } + + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) override + { + int occupancy = GetKernelOccupancy(); + int num_cu = KernelConfig::GetComputeUnitCount(); + + return std::make_unique(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op, + occupancy, + num_cu); + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::ostringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] + << ">"; + // clang-format on + + return str.str(); + } + + void SetDeviceKernelArgs(Argument& arg, + void* p_dev_kernel_args, + const void* p_host_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + hip_check_error(hipMemcpyAsync(p_dev_kernel_args, + p_host_kernel_args, + GetDeviceKernelArgSize(&arg), + hipMemcpyHostToDevice)); + } + + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, + void* p_dev_kernel_args, + const void* p_host_kernel_args) const override + { + return SetDeviceKernelArgs( + *dynamic_cast(p_arg), p_dev_kernel_args, p_host_kernel_args); + } + + void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + } + + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), p_dev_kernel_args); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->group_count_ * sizeof(KernelArguments); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 4492e6474f..a9e81f5563 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -4,6 +4,7 @@ #pragma once #include +#include #include #include @@ -26,6 +27,18 @@ namespace ck { namespace tensor_operation { namespace device { +// Dummy kernel to use as a fallback in the kernel selection logic +// Is not used in practice, but only used in case of misconfigured parameters +template +__global__ void kernel_dummy(const void CK_CONSTANT_ADDRESS_SPACE*, + const index_t, + const AElementwiseOperation, + const BElementwiseOperation, + const CDEElementwiseOperation) +{ +} /// /// @brief Entry point kernel for device-wide Grouped GEMM operation. /// @@ -528,6 +541,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; + using KernelConfig = TileLoopKernelConfig; using KernelArguments = GroupedGemmKernelArgument; using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2; @@ -574,22 +588,6 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop index_t tile_count_; }; - struct KernelConfig - { - // The oversubscription factor for the number of blocks that can simultaneously reside on - // GPU. - static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1; - // static constexpr int BLOCK_WAVES = BlockSize / get_warp_size(); - static constexpr int CU_SIMDS = 4; - // Assume we want to have at most 2 waves per SIMD - // static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); - static int GetCuBlocks() - { - int BLOCK_WAVES = BlockSize / get_warp_size(); - return math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); - } - }; - // Invoker struct Invoker : public BaseInvoker { @@ -666,58 +664,17 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop const void* dev_gemm_args, const StreamConfig& stream_config) const { - const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + const auto kernel = GetKernelFunction(); return LaunchKernel(kernel, arg, dev_gemm_args, stream_config); } - template - int CalculateMaxOccupancyGridSize(const KernelFunction& kernel, - const StreamConfig& stream_config) const - { - // Calculate max number of workgroups that can simultaneously reside on the CU. - int occ_num_blocks = 0; - size_t dyn_shared_mem_per_blk = 0; - hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &occ_num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk)); - - int cu_count = getAvailableComputeUnitCount(stream_config); - - if(stream_config.log_level_ > 0) - { - std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks - << ", available CUs count: " << cu_count << ", occup. grid size: " - << ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks()) * cu_count - << std::endl; - } - - return cu_count * ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks()); - } - template float LaunchKernel(const KernelFunction& kernel, const Argument& arg, const void* dev_gemm_args, const StreamConfig& stream_config) const { - int grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config); + int grid_size = KernelConfig::CalculateMaxOccupancyGridSize(kernel, stream_config); if(stream_config.log_level_ > 0) { @@ -835,65 +792,60 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop return IsSupportedArgument(*dynamic_cast(p_arg)); } - static int GetKernelOccupancy() + template + static auto GetKernelFunction() + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + return kernel; + } + + static auto GetKernelFunction() { - int occupancy = 0; if(get_warp_size() == 64) { if constexpr(NXdlPerWave64 > 0) { - const auto kernel = kernel_grouped_gemm_multiple_d_xdl; - hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + const auto kernel = GetKernelFunction(); + return kernel; } } else { - if constexpr(NXdlPerWave32 > 0) { - const auto kernel = kernel_grouped_gemm_multiple_d_xdl; - hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + const auto kernel = GetKernelFunction(); + return kernel; } } - return occupancy; + + // This is here to handle the case where MXdlPerWave/NxdPerWave is too small + // This is caught by IsSupportedArgument(), but as GetKernelFunction is sometimes called + // before we need a fallback kernel to return here. + return kernel_dummy; + } + + static int GetKernelOccupancy() + { + const auto kernel = GetKernelFunction(); + return KernelConfig::GetKernelOccupancy(kernel); } static auto MakeArgument(std::vector& p_As, @@ -906,13 +858,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop CDEElementwiseOperation cde_elementwise_op) { int occupancy = GetKernelOccupancy(); - int num_cu; - - hipDeviceProp_t dev_prop; - hipDevice_t dev; - hip_check_error(hipGetDevice(&dev)); - hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); - num_cu = dev_prop.multiProcessorCount; + int num_cu = KernelConfig::GetComputeUnitCount(); return Argument{p_As, p_Bs, @@ -937,13 +883,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop CDEElementwiseOperation cde_elementwise_op) override { int occupancy = GetKernelOccupancy(); - int num_cu; - - hipDeviceProp_t dev_prop; - hipDevice_t dev; - hip_check_error(hipGetDevice(&dev)); - hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); - num_cu = dev_prop.multiProcessorCount; + int num_cu = KernelConfig::GetComputeUnitCount(); return std::make_unique(p_As, p_Bs, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp index 6914def110..714d567020 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp @@ -126,7 +126,6 @@ template + typename ComputeTypeB = ComputeTypeA> struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK; // PermuteB not supported by DeviceBatchedGemm base class. + false, // PermuteA not supported by GridwiseOp + false>; // PermuteB not supported by DeviceGroupedGemm base class using CGridDesc_M_N = remove_cvref_t( @@ -779,7 +776,7 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK + typename LDSTypeB = ComputeTypeB, + bool NonTemporalLoadB = false> struct DeviceMoeGemmBlockScale : public DeviceGemmMultipleD_BlockScale_BPreshuffle; + LDSTypeB, + NonTemporalLoadB>; using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 2c17b82608..dc102ef805 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -8,6 +8,7 @@ #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/quantization_operation.hpp" +#include "ck/utility/type_convert.hpp" namespace ck { namespace tensor_operation { @@ -236,8 +237,9 @@ struct MultiplyAdd const half_t& d0, const half_t& d1) const { - const half_t y = type_convert(c) * d0 + d1; - e = y; + const half_t y = + type_convert(c * type_convert(d0) + type_convert(d1)); + e = y; } template <> __host__ __device__ void operator()(bhalf_t& e, @@ -245,8 +247,9 @@ struct MultiplyAdd const bhalf_t& d0, const bhalf_t& d1) const { - const bhalf_t y = type_convert(c) * d0 + d1; - e = y; + const bhalf_t y = + type_convert(c * type_convert(d0) + type_convert(d1)); + e = y; } template <> __host__ __device__ void operator()(float& e, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index c3c14edfb8..9f7fd47083 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -334,14 +334,14 @@ struct GridwiseGemm_wmma_cshuffle_v3 struct Problem { __host__ Problem() = default; - __host__ Problem(index_t M_, - index_t N_, - index_t K_, - std::array StrideAs_, - std::array StrideBs_, - std::array StrideDs_, - index_t StrideE_, - index_t KBatch_) + __host__ __device__ Problem(index_t M_, + index_t N_, + index_t K_, + std::array StrideAs_, + std::array StrideBs_, + std::array StrideDs_, + index_t StrideE_, + index_t KBatch_) : M{M_}, N{N_}, K{K_}, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 11e9a6dbf7..79549d6385 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -351,64 +351,65 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // Calculate grid size taking into account splitk (KBatch) // 2D grid (x,z) - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + __host__ __device__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); } // Calculate grid size taking into account splitk (KBatch) and multiple groups (Batch) // 3D grid (x,y,z) - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch) + __host__ __device__ static auto + CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch); } - __host__ static auto CalculateMPadded(index_t M) + __host__ __device__ static auto CalculateMPadded(index_t M) { return math::integer_least_multiple(M, MPerBlock); } - __host__ static auto CalculateNPadded(index_t N) + __host__ __device__ static auto CalculateNPadded(index_t N) { return math::integer_least_multiple(N, NPerBlock); } - __host__ static auto CalculateKPadded(index_t K) + __host__ __device__ static auto CalculateKPadded(index_t K) { return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; } - __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) { auto K_t = K_Batch * KPerBlock; return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); } - __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) { auto K_t = K_Batch * KPerBlock; return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); } - __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) { auto K_t = K_Batch * KPerBlock; return (K + K_t - 1) / K_t * KPerBlock; } - __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) { constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); auto K_t = K_Batch * KReadVec; return (K + K_t - 1) / K_t * KReadVec; } - __host__ static auto CalculateMBlock(index_t M) + __host__ __device__ static auto CalculateMBlock(index_t M) { return math::integer_divide_ceil(M, MPerBlock); } - __host__ static auto CalculateNBlock(index_t N) + __host__ __device__ static auto CalculateNBlock(index_t N) { return math::integer_divide_ceil(N, NPerBlock); } @@ -963,14 +964,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return true; } - __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / KPerBlock; return BlockwiseGemmPipe::BlockHasHotloop(num_loop); } - __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) { const index_t num_loop = K / KPerBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index c556dbec10..3b98798833 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -173,7 +173,8 @@ template + typename LDSTypeB = BDataType, + bool NonTemporalLoadB = false> struct GridwiseMoeGemmBlockScale { using AScaleType = float; @@ -1202,6 +1203,13 @@ struct GridwiseMoeGemmBlockScale BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { +#if defined(__gfx942__) || defined(__gfx950__) + constexpr auto b_coherence_flag = NonTemporalLoadB + ? AmdBufferCoherenceEnum::WAVE_NT1 + : AmdBufferCoherenceEnum::DefaultCoherence; +#else + constexpr auto b_coherence_flag = AmdBufferCoherenceEnum::DefaultCoherence; +#endif ignore = b_element_op; index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1)); index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); @@ -1300,15 +1308,16 @@ struct GridwiseMoeGemmBlockScale const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( + const auto b_grid_buf = make_dynamic_buffer( p_b_grid + expert_id * static_cast(expert_stride) / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); const auto a_scale_grid_buf = make_dynamic_buffer( p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); - const auto b_scale_grid_buf = make_dynamic_buffer( - p_b_scale_grid + expert_id * expert_scale_stride, - b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = + make_dynamic_buffer( + p_b_scale_grid + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -1465,9 +1474,11 @@ struct GridwiseMoeGemmBlockScale if constexpr(IsInputGemm && !IsSplitK) { const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; - const auto b_grid_buf_up = make_dynamic_buffer( - p_b_grid_up + expert_id * static_cast(expert_stride) / BPackedSize, - b_grid_desc_bpreshuffled.GetElementSpaceSize()); + const auto b_grid_buf_up = + make_dynamic_buffer( + p_b_grid_up + + expert_id * static_cast(expert_stride) / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< BDataType, BDataType, @@ -1485,9 +1496,10 @@ struct GridwiseMoeGemmBlockScale KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize; - const auto b_scale_grid_buf_up = make_dynamic_buffer( - p_b_scale_grid_up + expert_id * expert_scale_stride, - b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf_up = + make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( + const auto b_grid_buf = make_dynamic_buffer( p_b_grid + expert_id * static_cast(expert_stride) / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); const auto a_scale_grid_buf = make_dynamic_buffer( p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); - const auto b_scale_grid_buf = make_dynamic_buffer( - p_b_scale_grid + expert_id * expert_scale_stride, - b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = + make_dynamic_buffer( + p_b_scale_grid + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -2227,9 +2247,11 @@ struct GridwiseMoeGemmBlockScale if constexpr(IsInputGemm && !IsSplitK) { const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; - const auto b_grid_buf_up = make_dynamic_buffer( - p_b_grid_up + expert_id * static_cast(expert_stride) / BPackedSize, - b_grid_desc_bpreshuffled.GetElementSpaceSize()); + const auto b_grid_buf_up = + make_dynamic_buffer( + p_b_grid_up + + expert_id * static_cast(expert_stride) / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< BDataType, BDataType, @@ -2247,9 +2269,10 @@ struct GridwiseMoeGemmBlockScale KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize; - const auto b_scale_grid_buf_up = make_dynamic_buffer( - p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize, - b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf_up = + make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2 namespace ck { @@ -220,4 +221,49 @@ constexpr Tuple tie(Args&... args) noexcept return {args...}; } +// +// tuple_map: Map tuple with a different type +// e.g. tuple_map> becomes Tuple, Wrapper, Wrapper> +// +template