diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 11648bfd27..e4d0d47a2e 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,4 +1,4 @@ -* @zjing14 @asroy @junliume @illsilin @carlushuang +* @zjing14 @asroy @junliume @illsilin @carlushuang @aosewski # Documentation files docs/* @saadrahim @LisaDelaney *.md @saadrahim @LisaDelaney diff --git a/CHANGELOG.md b/CHANGELOG.md index abca69142e..c721039523 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,53 +2,66 @@ Full documentation for Composable Kernel is not yet available. -## (Unreleased) CK for ROCm 6.0.0 +## (Unreleased) CK ### Fixes - - Fixed a hazard associated with inline v_dot (#808) - - Fixed two bugs in grouped convolution backward data without K padding (#848 #876) +None ### Optimizations None ### Additions -- Added an image to a column kernel (#867) -- Added a column to an image kernel (#930) -- Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985) -- Grouped convolution support for small K and C (#822 #879 #897) -- Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) -- Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) -- Support for Batched Gemm DL (#732) -- Introduce wrapper sublibrary (limited functionality). (#1071, #1098, #1108) +* Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126) ### Changes - - Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) +None + +## CK for ROCm 6.0.0 + +### Fixes + * Fixed a hazard associated with inline v_dot (#808) + * Fixed two bugs in grouped convolution backward data without K padding (#848 #876) + +### Optimizations +None + +### Additions +* Added an image to a column kernel (#867) +* Added a column to an image kernel (#930) +* Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985) +* Grouped convolution support for small K and C (#822 #879 #897) +* Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) +* Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) +* Support for Batched Gemm DL (#732) + +### Changes + * Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) ## CK 0.2.0 for ROCm 5.7.0 ### Fixes -- Fixed a bug in 6-dimensional kernels (#555) -- Fixed a test case failure with grouped convolution backward weight (#524) +* Fixed a bug in 6-dimensional kernels (#555) +* Fixed a test case failure with grouped convolution backward weight (#524) ### Optimizations -- Improved the performance of the normalization kernel +* Improved the performance of the normalization kernel ### Additions -- New CMake flags: - - "DL_KERNELS"-- Must be set to "ON" in order to build the gemm_dl and batched_gemm_multi_d_dl instances - - "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types - - "INSTANCES_ONLY" -- Only builds CK library and instances without tests, examples, or profiler -- New feature: if GPU_TARGETS is not set in the CMake command line, CK will be built for all targets supported by the compiler -- Support for MI300A/MI300X -- Support for AMD RDNA 3 -- New user tutorial (#563) -- Additional instances for irregular GEMM sizes (#560) -- New inter-wave consumer-producer programming model for GEMM kernels (#310) -- GEMM with support multiple elementwise fusions (multi-D) (#534) -- Multi-embeddings support (#542) -- AMD RDNA 3 blockwise GEMM and real GEMM support (#541) -- AMD RDNA grouped convolution backward weight support (#505) -- MaxPool and AvgPool forward (#815); MaxPool backward (#750) +* New CMake flags: + * "DL_KERNELS"-* Must be set to "ON" in order to build the gemm_dl and batched_gemm_multi_d_dl instances + * "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types + * "INSTANCES_ONLY" -- Only builds CK library and instances without tests, examples, or profiler +* New feature: if GPU_TARGETS is not set in the CMake command line, CK will be built for all targets supported by the compiler +* Support for MI300A/MI300X +* Support for AMD RDNA 3 +* New user tutorial (#563) +* Additional instances for irregular GEMM sizes (#560) +* New inter-wave consumer-producer programming model for GEMM kernels (#310) +* GEMM with support multiple elementwise fusions (multi-D) (#534) +* Multi-embeddings support (#542) +* AMD RDNA 3 blockwise GEMM and real GEMM support (#541) +* AMD RDNA grouped convolution backward weight support (#505) +* MaxPool and AvgPool forward (#815); MaxPool backward (#750) ### Changes None diff --git a/Dockerfile b/Dockerfile index a805285a77..48ee97eec2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -122,7 +122,7 @@ ENV compiler_commit=$compiler_commit RUN sh -c "echo compiler version = '$compiler_version'" RUN sh -c "echo compiler commit = '$compiler_commit'" -RUN if ( [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \ +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ @@ -130,7 +130,7 @@ RUN if ( [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "am else echo "using the release compiler"; \ fi -RUN if ( [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" != "" ]; then \ +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" != "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ diff --git a/Jenkinsfile b/Jenkinsfile index e333a35ecd..80e7b044f1 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -84,7 +84,7 @@ def build_compiler(){ compiler = '/opt/rocm/bin/hipcc' } else{ - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ compiler = "/llvm-project/build/bin/clang++" } else{ @@ -293,7 +293,7 @@ def buildHipClangJob(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -304,7 +304,7 @@ def buildHipClangJob(Map conf=[:]){ gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { - timeout(time: 20, unit: 'HOURS') + timeout(time: 48, unit: 'HOURS') { cmake_build(conf) } @@ -348,7 +348,7 @@ def runCKProfiler(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -479,7 +479,7 @@ def Build_CK(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -560,7 +560,7 @@ def Build_CK(Map conf=[:]){ sh """#!/bin/bash mkdir -p build ls -ltr - CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="/opt/rocm;${env.WORKSPACE}/install" + CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install" cmake --build build -- -j """ } @@ -657,7 +657,7 @@ def process_results(Map conf=[:]){ //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.0;COMPILER_VERSION= 0 21 * * * % ROCMVERSION=6.0;COMPILER_VERSION=;COMPILER_COMMIT= - 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=;USE_SCCACHE=false + 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;COMPILER_COMMIT=;USE_SCCACHE=false 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false''' : "" pipeline { @@ -680,7 +680,7 @@ pipeline { string( name: 'COMPILER_VERSION', defaultValue: '', - description: 'Specify which version of compiler to use: release, amd-stg-open, amd-mainline-open, or leave blank (default).') + description: 'Specify which version of compiler to use: release, amd-staging, amd-mainline-open, or leave blank (default).') string( name: 'COMPILER_COMMIT', defaultValue: '', @@ -713,6 +713,10 @@ pipeline { name: "RUN_CPPCHECK", defaultValue: false, description: "Run the cppcheck static analysis (default: OFF)") + booleanParam( + name: "RUN_PERFORMANCE_TESTS", + defaultValue: false, + description: "Run the performance tests (default: OFF)") } environment{ dbuser = "${dbuser}" @@ -755,7 +759,11 @@ pipeline { -o -iname \'*.cl\' \ | grep -v 'build/' \ | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \ - /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include --file-filter=*.cpp --enable=all --output-file=ck_cppcheck.log" + /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ + -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 -D DL_KERNELS \ + -D __gfx908__ -D __gfx90a__ -D __gfx940__ -D __gfx941__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ + -U __gfx803__ -U __gfx900__ -U __gfx906__ -U CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 \ + --file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log" } steps{ buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) @@ -886,7 +894,7 @@ pipeline { { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() } + expression { !params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() } } options { retry(2) } agent{ label rocmnode("gfx908 || gfx90a")} @@ -902,7 +910,7 @@ pipeline { { when { beforeAgent true - expression { params.RUN_FULL_QA.toBoolean() } + expression { params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() } } options { retry(2) } agent{ label rocmnode("gfx90a")} @@ -921,6 +929,10 @@ pipeline { parallel { stage("Process results"){ + when { + beforeAgent true + expression { params.RUN_PERFORMANCE_TESTS.toBoolean() } + } agent { label 'mici' } steps{ process_results() diff --git a/client_example/05_layernorm/CMakeLists.txt b/client_example/05_layernorm/CMakeLists.txt index 246f877cde..b7b3c830ed 100644 --- a/client_example/05_layernorm/CMakeLists.txt +++ b/client_example/05_layernorm/CMakeLists.txt @@ -1,6 +1,9 @@ add_executable(client_layernorm2d_bwd_data layernorm2d_bwd_data.cpp) target_link_libraries(client_layernorm2d_bwd_data PRIVATE composable_kernel::device_other_operations) +add_executable(client_layernorm2d_bwd_gamma_beta layernorm2d_bwd_gamma_beta.cpp) +target_link_libraries(client_layernorm2d_bwd_gamma_beta PRIVATE composable_kernel::device_other_operations) + add_executable(client_layernorm2d_fwd layernorm2d_fwd.cpp) target_link_libraries(client_layernorm2d_fwd PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp b/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp new file mode 100644 index 0000000000..98b394add6 --- /dev/null +++ b/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp" + +#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DGammaDataType = float; +using DBetaDataType = float; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t M = 1024; + ck::index_t N = 1024; + + SimpleDeviceMem dy_dev(sizeof(DYDataType) * M * N); + SimpleDeviceMem x_dev(sizeof(XDataType) * M * N); + SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * M); + SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * M); + SimpleDeviceMem dgamma_dev(sizeof(DGammaDataType) * N); + SimpleDeviceMem dbeta_dev(sizeof(DBetaDataType) * N); + + using DeviceOp = + ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::size_t num_bytes = sizeof(DYDataType) * M * N + sizeof(XDataType) * M * N + + sizeof(MeanInvStdDataType) * M * 2 + sizeof(DGammaDataType) * N + + sizeof(DBetaDataType) * N; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // inLengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N}, // outLengths + {1}, // dgammaStrides + {1}, // dbetaStrides + {0}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // inLengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N}, // outLengths + {1}, // dgammaStrides + {1}, // dbetaStrides + {0}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/18_groupnorm/CMakeLists.txt b/client_example/18_groupnorm/CMakeLists.txt index deb50f6fce..e04c26d8e7 100644 --- a/client_example/18_groupnorm/CMakeLists.txt +++ b/client_example/18_groupnorm/CMakeLists.txt @@ -1,5 +1,8 @@ add_executable(client_groupnorm_bwd_data groupnorm_bwd_data.cpp) target_link_libraries(client_groupnorm_bwd_data PRIVATE composable_kernel::device_other_operations) +add_executable(client_groupnorm_bwd_gamma_beta groupnorm_bwd_gamma_beta.cpp) +target_link_libraries(client_groupnorm_bwd_gamma_beta PRIVATE composable_kernel::device_other_operations) + add_executable(client_groupnorm_swish_fwd groupnorm_swish_fwd.cpp) target_link_libraries(client_groupnorm_swish_fwd PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp b/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp new file mode 100644 index 0000000000..c2fbe285df --- /dev/null +++ b/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp" + +#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DGammaDataType = float; +using DBetaDataType = float; + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 32; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t G = 64; + ck::index_t C = 128; + + std::size_t length = N * H * W * G * C; + + std::vector strideDy = {H * W * G * C, W * G * C, G * C, C, 1}; + std::vector strideX = strideDy; + std::vector strideMeanInvStd = {G, 0, 0, 1, 0}; + std::vector strideDGammaBeta = {C, 1}; + + SimpleDeviceMem dy_dev(sizeof(DYDataType) * length); + SimpleDeviceMem x_dev(sizeof(XDataType) * length); + SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * N * G); + SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * N * G); + SimpleDeviceMem dgamma_dev(sizeof(DGammaDataType) * G * C); + SimpleDeviceMem dbeta_dev(sizeof(DBetaDataType) * G * C); + + using DeviceOp = + ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::size_t num_bytes = sizeof(DYDataType) * length + sizeof(XDataType) * length + + sizeof(GammaDataType) * G * C + sizeof(MeanInvStdDataType) * N * G * 2 + + sizeof(DGammaDataType) * G * C + sizeof(DBetaDataType) * G * C; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, + strideDy, + strideX, + strideMeanInvStd, + strideMeanInvStd, + {G, C}, + strideDGammaBeta, + strideDGammaBeta, + {0, 1, 2}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, + strideDy, + strideX, + strideMeanInvStd, + strideMeanInvStd, + {G, C}, + strideDGammaBeta, + strideDGammaBeta, + {0, 1, 2}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/25_tensor_transforms/tensor_transform.cpp b/client_example/25_tensor_transforms/tensor_transform.cpp deleted file mode 100644 index 41ceec1cb5..0000000000 --- a/client_example/25_tensor_transforms/tensor_transform.cpp +++ /dev/null @@ -1,150 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" - -#include "ck/utility/number.hpp" -#include "ck/utility/tuple.hpp" -#include "ck/utility/sequence.hpp" - -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_description/multi_index_transform_helper.hpp" - -static constexpr auto I0 = ck::Number<0>{}; -static constexpr auto I1 = ck::Number<1>{}; -static constexpr auto I2 = ck::Number<2>{}; - -using DataType = int; - -template -void Print1d(const Desc& desc) -{ - std::cout << "Print1d" << std::endl; - for(ck::index_t w = 0; w < desc.GetLength(I0); w++) - { - std::cout << desc.CalculateOffset(ck::make_tuple(w)) << " "; - } - std::cout << std::endl; -} - -template -void Print2d(const Desc& desc) -{ - std::cout << "Print2d" << std::endl; - for(ck::index_t h = 0; h < desc.GetLength(I0); h++) - { - for(ck::index_t w = 0; w < desc.GetLength(I1); w++) - { - std::cout << desc.CalculateOffset(ck::make_tuple(h, w)) << " "; - } - std::cout << std::endl; - } -} - -template -void Print3dCustom(const Desc& desc) -{ - std::cout << "Print3dCustom" << std::endl; - for(ck::index_t d = 0; d < desc.GetLength(I0); d++) - { - for(ck::index_t h = 0; h < desc.GetLength(I1); h++) - { - for(ck::index_t w = 0; w < desc.GetLength(I2); w++) - { - std::cout << desc.CalculateOffset(ck::make_tuple(d, h, w)) << " "; - } - std::cout << std::endl; - } - std::cout << std::endl; - } -} - -int main() -{ - // Tensor descriptor traverse in row-major (need to reverse dims) - std::cout << "Note: Tensor descriptor traverse in row-major" << std::endl; - // Basic descriptor 0, 1, 2, ... 30, 31 - // (dims:4,8 strides:1,4) - const auto desc_4x8_s1x4 = - ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}), - ck::make_tuple(ck::Number<1>{}, ck::Number<4>{})); - std::cout << "dims:4,8 strides:1,4" << std::endl; - Print2d(desc_4x8_s1x4); - - using Cord1x1Type = ck::Tuple, ck::Number<1>>; - constexpr ck::index_t offset_1x1 = desc_4x8_s1x4.CalculateOffset(Cord1x1Type{}); - std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl; - - // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor) - // dims:4,(2,4) strides:2,(1,8) - const auto desc_4x2x4_s2x1x8 = - ck::make_naive_tensor_descriptor(ck::make_tuple(4, 2, 4), ck::make_tuple(2, 1, 8)); - // Transform to 2d (column-major, need to to reverse dims) - const auto desc_4x2x4_s2x1x8_merged = ck::transform_tensor_descriptor( - desc_4x2x4_s2x1x8, - ck::make_tuple(ck::make_pass_through_transform(4), - ck::make_merge_transform(ck::make_tuple(4, 2))), - ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<2, 1>{}), - ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); - - std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl; - Print2d(desc_4x2x4_s2x1x8_merged); - - // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor) - // dims:(2,2),(2,4) strides:((1,4),(2,8) - const auto desc_2x2x2x4_s1x4x2x8 = - ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8)); - // Transform to 2d - const auto desc_2x2x2x4_s1x4x2x8_double_merged_2d = ck::transform_tensor_descriptor( - desc_2x2x2x4_s1x4x2x8, - ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)), - ck::make_merge_transform(ck::make_tuple(4, 2))), - ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}), - ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); - // Transform to 3d - const auto desc_2x2x2x4_s1x4x2x8_double_merged_3d = ck::transform_tensor_descriptor( - desc_2x2x2x4_s1x4x2x8, - ck::make_tuple(ck::make_pass_through_transform(2), - ck::make_pass_through_transform(2), - ck::make_merge_transform(ck::make_tuple(4, 2))), - ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<3, 2>{}), - ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{})); - - std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl; - Print2d(desc_2x2x2x4_s1x4x2x8_double_merged_2d); - Print3dCustom(desc_2x2x2x4_s1x4x2x8_double_merged_3d); - - // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor) - // dims:((2,2),2),4 strides:((1,4),2),8 - // Transform to 2d - const auto desc_2x2x2x4_s1x4x2x8_nested = - ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8)); - const auto desc_2x2x2x4_s1x4x2x8_nested_merged_3d = ck::transform_tensor_descriptor( - desc_2x2x2x4_s1x4x2x8_nested, - ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)), - ck::make_pass_through_transform(2), - ck::make_pass_through_transform(4)), - ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}), - ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{})); - const auto desc_2x2x2x4_s1x4x2x8_nested_merged_1d = ck::transform_tensor_descriptor( - desc_2x2x2x4_s1x4x2x8_nested, - ck::make_tuple(ck::make_merge_transform(ck::make_tuple(4, 2, 2, 2))), - ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}), - ck::make_tuple(ck::Sequence<0>{})); - const auto desc_2x2x2x4_s1x4x2x8_nested_merged_2d = ck::transform_tensor_descriptor( - desc_2x2x2x4_s1x4x2x8_nested_merged_3d, - ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 4)), - ck::make_pass_through_transform(4)), - ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}), - ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); - - std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl; - Print1d(desc_2x2x2x4_s1x4x2x8_nested_merged_1d); - Print2d(desc_2x2x2x4_s1x4x2x8_nested_merged_2d); - Print3dCustom(desc_2x2x2x4_s1x4x2x8_nested_merged_3d); - - return 0; -} diff --git a/client_example/25_tensor_transforms/CMakeLists.txt b/client_example/25_wrapper/CMakeLists.txt similarity index 55% rename from client_example/25_tensor_transforms/CMakeLists.txt rename to client_example/25_wrapper/CMakeLists.txt index d1543fb0ef..eb3be0e6c8 100644 --- a/client_example/25_tensor_transforms/CMakeLists.txt +++ b/client_example/25_wrapper/CMakeLists.txt @@ -1,4 +1,4 @@ -add_executable(client_tensor_transform tensor_transform.cpp) -target_link_libraries(client_tensor_transform PRIVATE composable_kernel::device_other_operations) add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp) target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations) +add_executable(client_wrapper_img2col wrapper_img2col.cpp) +target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp b/client_example/25_wrapper/tensor_transform_using_wrapper.cpp similarity index 98% rename from client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp rename to client_example/25_wrapper/tensor_transform_using_wrapper.cpp index de9fcde0b4..4b25d85e2d 100644 --- a/client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp +++ b/client_example/25_wrapper/tensor_transform_using_wrapper.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/client_example/25_wrapper/wrapper_img2col.cpp b/client_example/25_wrapper/wrapper_img2col.cpp new file mode 100644 index 0000000000..35074be4c1 --- /dev/null +++ b/client_example/25_wrapper/wrapper_img2col.cpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" + +static constexpr ck::index_t NumDimSpatial = 3; +using DataType = float; +using InputLayout = ck::tensor_layout::convolution::NDHWGC; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +// Test copy from Global to Global through LDS and VGPR +template +__global__ void DeviceImageToColumnPad0(InputTensor input_tensor, + OutputTensor output_tensor, + const BlockShape tile_shape, + const ThreadLayoutShape thread_layout) +{ + const ck::index_t block_idx = static_cast(blockIdx.x); + + // Get local tiles for global memory + auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx); + auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx); + + // Get partition per thread + const auto input_local_partition = + ck::wrapper::make_local_partition(input_local_tile, thread_layout, threadIdx.x); + auto output_local_partition = + ck::wrapper::make_local_partition(output_local_tile, thread_layout, threadIdx.x); + + // Perform copy + using DimAccessOrder = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t vector_dim = 1; + constexpr ck::index_t scalar_per_vector = 4; + ck::wrapper::copy(input_local_partition, + output_local_partition); +} + +void PerformImageToColumnPad0(const ck::index_t G, + const ck::index_t N, + const ck::index_t Di, + const ck::index_t Hi, + const ck::index_t Wi, + const ck::index_t Do, + const ck::index_t Ho, + const ck::index_t Wo, + const ck::index_t C, + const ck::index_t Z, + const ck::index_t Y, + const ck::index_t X, + std::array filter_strides, + std::array filter_dilations) +{ + const ck::index_t ZYXC = Z * Y * X * C; + const ck::index_t GC = G * C; + + // shape: (G, (Wo, Ho, Do, N)), (C, X, Y, Z)) + const auto shape = ck::make_tuple(ck::make_tuple(G, ck::make_tuple(Wo, Ho, Do, N)), + ck::make_tuple(C, X, Y, Z)); + const auto in_strides = + ck::make_tuple(ck::make_tuple(C, + ck::make_tuple(filter_strides[2] * GC, + filter_strides[1] * Wi * GC, + filter_strides[0] * Hi * Wi * GC, + Di * Hi * Wi * GC)), + ck::make_tuple(1, + filter_dilations[2] * GC, + filter_dilations[1] * Wi * GC, + filter_dilations[0] * Hi * Wi * GC)); + const auto in_layout = ck::wrapper::make_layout(shape, in_strides); + + const auto out_strides = ck::make_tuple( + ck::make_tuple( + ZYXC, + ck::make_tuple(ZYXC * G, Wo * ZYXC * G, Ho * Wo * ZYXC * G, Do * Ho * Wo * ZYXC * G)), + ck::make_tuple(1, C, X * C, Y * X * C)); + const auto out_layout = ck::wrapper::make_layout(shape, out_strides); + + const ck::index_t input_size = N * Di * Hi * Wi * GC; + // Global memory buffers + SimpleDeviceMem in_buf(input_size * sizeof(DataType)); + SimpleDeviceMem out_buf(ck::wrapper::size(out_layout) * sizeof(DataType)); + + // User can choose appropriate number of threads and sizes per block + const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}); + // This example doesn't support padding, user should select tile sizes + // which divides the shape completely + const auto tile_shape = ck::make_tuple(ck::Number<32>{}, ck::Number<64>{}); + + // Create buffers for global memory + auto input_tensor_global = ck::wrapper::make_tensor( + static_cast(in_buf.GetDeviceBuffer()), in_layout); + auto output_tensor_global = ck::wrapper::make_tensor( + static_cast(out_buf.GetDeviceBuffer()), out_layout); + + const ck::index_t grid_size = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout), + ck::wrapper::size<0>(tile_shape)) * + ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout), + ck::wrapper::size<1>(tile_shape)); + + const auto kernel = DeviceImageToColumnPad0; + const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, + kernel, + dim3(grid_size), + dim3(ck::wrapper::size(thread_layout)), + 0, + input_tensor_global, + output_tensor_global, + tile_shape, + thread_layout); + + std::size_t num_btype = G * N * Do * Ho * Wo * ZYXC * 2 * sizeof(DataType); + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << std::endl; +} + +int main(int argc, char* argv[]) +{ + constexpr ck::index_t G = 4; // number of groups + constexpr ck::index_t N = 32; // batch + constexpr ck::index_t C = 64; // input channel (per group) + constexpr ck::index_t Z = 3; // filter D + constexpr ck::index_t Y = 3; // filter H + constexpr ck::index_t X = 3; // filter W + constexpr ck::index_t Di = 9; // input D + constexpr ck::index_t Hi = 9; // input H + constexpr ck::index_t Wi = 7; // input W + constexpr ck::index_t Do = 7; // output D + constexpr ck::index_t Ho = 7; // output H + constexpr ck::index_t Wo = 5; // output W + PerformImageToColumnPad0(G, + N, + Di, + Hi, + Wi, + Do, + Ho, + Wo, + C, + Z, + Y, + X, + {1, 1, 1} /*filter_strides*/, + {1, 1, 1} /*filter_dilations*/); + return 0; +} diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 23a4c4bb91..88142aa373 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.31.0 +rocm-docs-core==0.33.0 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 1e5e688dac..12414c7470 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -113,7 +113,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.31.0 +rocm-docs-core==0.33.0 # via -r requirements.in six==1.16.0 # via diff --git a/docs/wrapper.rst b/docs/wrapper.rst index c050f17caf..79b6c75580 100644 --- a/docs/wrapper.rst +++ b/docs/wrapper.rst @@ -18,8 +18,7 @@ Description The CK library provides a lightweight wrapper for more complex operations implemented in -the library. It allows indexing of nested layouts using a simple interface -(avoiding complex descriptor transformations) and memory access (using Tensor). +the library. Example: @@ -54,6 +53,11 @@ Output:: 1 5 9 13 17 21 25 29 2 6 10 14 18 22 26 30 + +Advanced examples: + +* `Image to column `_ + ------------------------------------- Layout ------------------------------------- diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 56897571c7..5b71cd1548 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -19,6 +19,9 @@ add_custom_target(example_gemm_xdl) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16) +add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2) + add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) diff --git a/example/01_gemm/gemm_dl_int4.cpp b/example/01_gemm/gemm_dl_int4.cpp index e55ae14013..43c0cfe2e0 100644 --- a/example/01_gemm/gemm_dl_int4.cpp +++ b/example/01_gemm/gemm_dl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -43,3 +41,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +#endif \ No newline at end of file diff --git a/example/01_gemm/gemm_xdl_fp16_v2.cpp b/example/01_gemm/gemm_xdl_fp16_v2.cpp new file mode 100644 index 0000000000..eba0ea9d11 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_v2.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using F16 = ck::half_t; +using F32 = float; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV2< + ALayout, BLayout, CLayout, + F16, F16, F16, F32, F16, + PassThrough, PassThrough, PassThrough, GemmDefault, + 2, 256, + 256, 256, + 32, 8, 4, + 32, 32, + 4, 4, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 4, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::LoopScheduler::Default, ck::PipelineVersion::v1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_int4.cpp b/example/01_gemm/gemm_xdl_int4.cpp index f6238c7aa5..fb4f383fae 100644 --- a/example/01_gemm/gemm_xdl_int4.cpp +++ b/example/01_gemm/gemm_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -44,3 +42,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +#endif \ No newline at end of file diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp index f206bbeb41..1d0b0f7861 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -58,3 +56,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; #include "run_convnd_fwd_max_example.inc" int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); } +#endif diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index e363dc5c12..62295c57eb 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -272,15 +272,14 @@ int main(int argc, char* argv[]) { for(int m = 0; m < M; ++m) { - auto reduce0_acc = reduce0_op.GetIdentityValue(); - auto reduce1_acc = reduce1_op.GetIdentityValue(); - + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); + ReduceAccDataType d0_val = 0; + ReduceAccDataType d1_val = 0; for(int n = 0; n < N; ++n) { auto c_val = ck::type_convert(c_g_m_n_host_result(batch, m, n)); - ReduceAccDataType d0_val; - ReduceAccDataType d1_val; UnaryIdenticElementOp{}(d0_val, c_val); UnarySquareElementOp{}(d1_val, c_val); diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp index 5494563fdd..6f91d51a5f 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -29,3 +27,4 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; #include "run_grouped_conv_fwd_bias_relu_add_example.inc" int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } +#endif diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp index d166214c33..2caee6b8dc 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp @@ -9,9 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o Gemm1 */ -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include #include @@ -144,3 +142,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); #endif int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; } +#endif diff --git a/example/35_splitK_gemm/run_splitK_gemm_example.inc b/example/35_splitK_gemm/run_splitK_gemm_example.inc index e9bd5c552d..e3690984ab 100644 --- a/example/35_splitK_gemm/run_splitK_gemm_example.inc +++ b/example/35_splitK_gemm/run_splitK_gemm_example.inc @@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con if(config.time_kernel) { - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 1}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp index 74fb16e15b..dc54bc30ef 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp @@ -42,7 +42,7 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::KPadding; using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle // clang-format off diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp index 80f6e9ae05..cf7b1ce3a8 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include #include @@ -120,3 +118,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); #endif int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; } +#endif diff --git a/example/48_pool3d_fwd/pool3d_fwd_common.hpp b/example/48_pool3d_fwd/pool3d_fwd_common.hpp index 39032fa123..788f38ec52 100644 --- a/example/48_pool3d_fwd/pool3d_fwd_common.hpp +++ b/example/48_pool3d_fwd/pool3d_fwd_common.hpp @@ -32,6 +32,8 @@ std::vector f_tensor_strides_ncdhw(ck::index_t N_, return {C_ * D * H * W, D * H * W, H * W, W, 1_uz}; else if constexpr(ck::is_same::value) return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}; + throw std::runtime_error("Pool3d_fwd: problem with layout. "); + return {0, 0, 0, 0, 0}; }; template @@ -53,6 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, return HostTensorDescriptor({N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); } + throw std::runtime_error("Pool3d_fwd: problem with layout. "); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); }; template f_tensor_strides_ncdhw(ck::index_t N_, return {C_ * D * H * W, D * H * W, H * W, W, 1_uz}; else if constexpr(ck::is_same::value) return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}; + throw std::runtime_error("Avgpool3d_bwd: problem with layout. "); + return {0, 0, 0, 0, 0}; }; template @@ -47,6 +49,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, return HostTensorDescriptor({N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); } + throw std::runtime_error("Avgpool3d_bwd: problem with layout. "); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); }; template +struct BlockwiseGemmXdlops_pipeline_hotloop_inst +{ + static constexpr index_t WaveSize = 64; + static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL); + + static constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth); + static constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth); + + static constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth); + static constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth); + + static constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth); + static constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth); + + static constexpr index_t C_MFMA_Inst_Num = + MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + static constexpr auto Print() + { + printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n", + BlockSize, + WaveSize, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + KPerXDL); + + printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " + "%d, %d\n C MFMA inst: %d\n", + A_Buffer_Load_Inst_Num, + B_Buffer_Load_Inst_Num, + A_LDS_Write_Inst_Num, + B_LDS_Write_Inst_Num, + A_LDS_Read_Inst_Num, + B_LDS_Read_Inst_Num, + C_MFMA_Inst_Num); + } +}; + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename ATileDesc, + typename BTileDesc, + typename AMmaTileDesc, + typename BMmaTileDesc, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t MPerXDL, + index_t NPerXDL, + index_t MRepeat, + index_t NRepeat, + index_t KPack, + bool TransposeC = false, + index_t AMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops, + index_t BMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops> +struct BlockwiseGemmXdlops_pipeline_v4 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); + static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = + XdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + static constexpr index_t KRepeat = KPerThread / KPack; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + using HotLoopInstList = BlockwiseGemmXdlops_pipeline_hotloop_inst; + + static_assert(KPerThread % KPack == 0, + "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple( + m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]); + } + + using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + + __host__ __device__ + BlockwiseGemmXdlops_pipeline_v4(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + + // HotLoopInstList::Print(); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, N, M0, M1, M2)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + __device__ static constexpr auto HotLoopScheduler() + { + // schedule + constexpr auto num_ds_read_inst = + HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst = + HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num; + ; + constexpr auto num_buffer_load_inst = + HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num; + ; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto num_issue = num_buffer_load_inst; + + static_for<0, num_issue, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier( + 0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier( + 0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA + }); + } + + template + __device__ static constexpr auto TailScheduler() + { + } + + template <> + __device__ static constexpr auto TailScheduler<1>() + { + // schedule + constexpr auto num_ds_read_inst = + HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst = + HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num; + ; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto num_issue = num_ds_write_inst; + + static_for<0, num_issue, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier( + 0x100, num_ds_read_inst / num_ds_write_inst - 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_inst / num_ds_write_inst - 3, 0); // MFMA + }); + } + + template <> + __device__ static constexpr auto TailScheduler<2>() + { + // schedule + constexpr auto num_ds_read_inst = + HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto num_issue = num_ds_read_inst; + + static_for<0, num_issue, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_inst / num_ds_read_inst, 0); // MFMA + }); + } + + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; + static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs; + // Inst List: + // ds_read_b128: 16 + // ds_write_b128: 8 + // buffer_load_dwordx4: 16 + // v_mfma: 0 + // ------------------------------------------------------------------------------------------- + + // Global prefetch 1th, Fill Ping LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0)); + + // Local prefetch 1th, Fill Ping Reg + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(I0)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(I0), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(I0)); + }); + }); + }); + + // Global prefetch 2th, Fill Pong LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1)); + + // Global prefetch 3rd + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + // This hot loop has two legacy loopover, to implement the double local buffer strategy + do + { + // ------------------------------------------------------------------------------------------- + using PingP1 = Number<0>; + using PongP1 = Number<1>; + // MFMA: Ping Reg + // DS_WRITE: To Ping LDS + // DS_READ: Pong LDS to Pong Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{})); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{})); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + // ------------------------------------------------------------------------------------------- + using PingP2 = Number<1>; + using PongP2 = Number<0>; + // MFMA: Pong Reg + // DS_WRITE: To Pong LDS + // DS_READ: Ping LDS to Ping Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP2{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP2{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP2{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP2{})); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP2{})); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP2{})); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 2; + } while(i < (num_loop - 3)); + } + + // tail + if constexpr(TailNum == 3) + { + using PingP1 = Number<0>; + using PongP1 = Number<1>; + // MFMA: Ping Reg + // DS_WRITE: To Ping LDS + // DS_READ: Pong LDS to Pong Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{})); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{})); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + TailScheduler<1>(); + __builtin_amdgcn_sched_barrier(0); + + // ------------------------------------------------------------------------------------------- + using PingP2 = Number<1>; + using PongP2 = Number<0>; + // MFMA: Pong Reg + // DS_WRITE: To Pong LDS + // DS_READ: Ping LDS to Ping Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP2{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP2{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP2{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP2{})); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + TailScheduler<2>(); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PongP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PongP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + // 64 v_mfma + __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == 2) + { + using PingP1 = Number<0>; + using PongP1 = Number<1>; + // MFMA: Ping Reg + // DS_WRITE: To Ping LDS + // DS_READ: Pong LDS to Pong Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + TailScheduler<2>(); + __builtin_amdgcn_sched_barrier(0); + + // ------------------------------------------------------------------------------------------- + using PingP2 = Number<1>; + // MFMA: Pong Reg + // DS_WRITE: To Pong LDS + // DS_READ: Ping LDS to Ping Reg + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + // 64 v_mfma + __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA + __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + // M1, N1 as double buffer index + // Read buffer + Compute buffer + // A[M0, M1, M2, KPack] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // B[N0, N1, N2, KPack] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp new file mode 100644 index 0000000000..d49c63f147 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp @@ -0,0 +1,306 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm +{ + using DeviceOp = DeviceGemm_Xdl_CShuffleV2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v2< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + float ave_time = 0; + const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1; + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K) == 3) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v2; + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v2; + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemm_Xdl_CShuffleV2" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 7bb47e9d3c..6266fb40f0 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -134,6 +134,11 @@ struct BlockToCTileMap_M00_N0_M01Adapt __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) : M_(M), N_(N), M01_(M01) { +#if 0 + if(get_thread_global_1d_id()==0){ + printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_); + } +#endif } template @@ -252,6 +257,302 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt +struct BlockToCTileMap_Grouped_M00_N0_M01Adapt; + +template +struct BlockToCTileMap_Grouped_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default; + + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt( + const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default; + __host__ __device__ + BlockToCTileMap_Grouped_M00_N0_M01Adapt(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default; + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt& + operator=(const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default; + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt& + operator=(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default; + + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M, + index_t N, + index_t M01 = 8) + : M_(M), N_(N), M01_(M01) + { +#if 0 + if(get_thread_global_1d_id()==0){ + printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_); + } +#endif + } + + template + __host__ __device__ + BlockToCTileMap_Grouped_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8) + : BlockToCTileMap_Grouped_M00_N0_M01Adapt( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01) + { + } + + __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + + template + __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + + const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum); + auto group_id = block_1d_id % GroupNum; + auto remap_block_1d_id = group_id * group_size + block_1d_id / GroupNum; + + index_t idx_N0 = remap_block_1d_id % N0; + index_t idx_M0 = remap_block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + /** + * idxN0 + * + * |< mtx N >| + * + * NPerBlock NPerBlock NPerBlock NPerBlock + * N_0 N_1 N_2 N_3 + * - |-----------|-----------|-----------|-----|-----|- + * ^ | - - 0 |/----> 2 | | | | + * | | | / | | | | | M_0 MPerBlock + * | M | /| | | | | | + * |-0---|---/-|-----|-----|-----------|-----|-----|- + * | 1 | / | | | blockid | | | + * idxM0 | | | / | V | 5 | | | M_1 MPerBlock + * | - V 1 | - 3 | | | | + * |-----------|-----------|-----------|-----|-----|- + * mtx M | | | | | | + * | | | | | | M_2 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * | | | | | | + * | | | | | | M_3 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * V | | | | | | + * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * Example: + * assume: + * M0 = 5 + * N0 = 4 + * block_1d_id = 5 + * M01 = 2 + * + * idx_N0 = 1 + * idx_M0 = 1 + * M01_adapt = 2 + * idx_M00 = 0 + * idx_M01 = 1 + * idx_N0_M01_local = 5 + * output {1, 2} + */ + + return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t M01_; +}; + +// keep the redundant type argument for backward compatibility +template +struct BlockToCTileMap_Grouped_M00_N0_M01Adapt + : BlockToCTileMap_Grouped_M00_N0_M01Adapt +{ + using BlockToCTileMap_Grouped_M00_N0_M01Adapt:: + BlockToCTileMap_Grouped_M00_N0_M01Adapt; +}; + +// columns of row-vectors +// This C-tile map dynamically adjusts N01 when C-tile index is out of range +template +struct BlockToCTileMap_N00_M0_N01Adapt; + +template +struct BlockToCTileMap_N00_M0_N01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt() = default; + + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt&) = + default; + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt&&) = + default; + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt& + operator=(const BlockToCTileMap_N00_M0_N01Adapt&) = default; + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt& + operator=(BlockToCTileMap_N00_M0_N01Adapt&&) = default; + + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01 = 8) + : M_(M), N_(N), N01_(N01) + { +#if 0 + if(get_thread_global_1d_id()==0){ + printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_); + } +#endif + } + + template + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t N01 = 8) + : BlockToCTileMap_N00_M0_N01Adapt( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), N01) + { + } + + __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + + template + __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + + index_t idx_M0 = block_1d_id % M0; + index_t idx_N0 = block_1d_id / M0; + + const auto N01_adapt = (idx_N0 < N0 - N0 % N01_) ? N01_ : N0 % N01_; + + index_t idx_N00 = idx_N0 / N01_; + index_t idx_N01 = idx_N0 % N01_; + index_t idx_M0_N01_local = idx_M0 + idx_N01 * M0; + + /** + * idxN0 + * + * |< mtx N >| + * + * |<---N01--->| + * - |-----------|-----------|-----------|-----|-----|- + * ^ | 0 ----------> 1 | | | | + * | | / | | | | M_0 MPerBlock + * | / | | | | + * |------/----------------|-----------|-----|-----|- + * | | | | | | | + * idxM0 | V | | | | | M_1 MPerBlock + * | 2 ----------> 3 | | | | + * |-----------|-----------|-----------|-----|-----|- + * mtx M | | blockid | | | | + * | | 5 | | | | M_2 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * | | | | | | + * | | | | | | M_3 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * V | | | | | | + * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * NPerBlock NPerBlock NPerBlock NPerBlock + * N_0 N_1 N_2 N_3 + * Example: + * assume: + * N0 = 5 + * M0 = 4 + * block_1d_id = 5 + * N01 = 2 + * + * idx_M0 = 1 + * idx_N0 = 1 + * N01_adapt = 2 + * idx_N00 = 0 + * idx_N01 = 1 + * idx_M0_N01_local = 5 + * output {2, 1} + */ + + return make_tuple(idx_M0_N01_local / N01_adapt, + idx_M0_N01_local % N01_adapt + idx_N00 * N01_); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t N01_; +}; + // 2D slices of column-vectors in 3D space // This C-tile map dynamically adjusts M01 when C-tile index is out of range template diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp new file mode 100644 index 0000000000..2ad2dd9915 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -0,0 +1,1153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +#endif + kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid, + const FloatB* p_b_grid, + FloatC* p_c_grid, + typename GridwiseGemm::Problem problem) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = problem; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_v2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + return CalculateKPadded(K) / AK1Value; + } + else + { + return K / AK1Value; + } + } + + __host__ static auto CalculateBK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + return CalculateKPadded(K) / BK1Value; + } + else + { + return K / BK1Value; + } + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_floor(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_floor(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KPadded{CalculateKPadded(K_)}, + AK0{CalculateAK0(K_)}, + BK0{CalculateBK0(K_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const FloatA* p_a_grid; + const FloatB* p_b_grid; + FloatC* p_c_grid; + }; + + // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) + + b_block_space_size_aligned * sizeof(ComputeTypeB)), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding) + { + if(!(CalculateKPadded(problem.K) % AK1Value == 0) || + !(CalculateKPadded(problem.K) % BK1Value == 0)) + { + return false; + } + } + else + { + if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock; + + if(num_k_loop < 4) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return num_loop > 3; + } + + __host__ static constexpr index_t CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + if(num_loop % 2 == 1) + return 3; + else + return 2; + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const FloatA* p_a_grid, + const FloatB* p_b_grid, + FloatC* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } +#if 0 + if(threadIdx.x == 0){ + printf("Hardware assigned No. %03d workgroup of logical C tile (%02d, %02d) on %d th XCC Die, %d th SE, %d th CU\n", + get_block_1d_id(), + block_work_idx[I0], + block_work_idx[I1], + __smid()>>6 & 0xf, + __smid()>>4 & 0x3, + __smid() & 0xf); + } +#endif + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatA, + ComputeTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatB, + ComputeTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + // BlockSize, + // ComputeType, + // FloatGemmAcc, + // decltype(a_block_desc_ak0_m_ak1), + // decltype(b_block_desc_bk0_n_bk1), + // MPerXdl, + // NPerXdl, + // MXdlPerWave, + // NXdlPerWave, + // KPack, + // LoopSched>(); + auto blockwise_gemm_pipeline = BlockwiseGemmXdlops_pipeline_v4< + BlockSize, + ComputeTypeA, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)), + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>{}; // TransposeC + + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // gridwise GEMM pipeline + static_assert(std::is_default_constructible_v); + // const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 7bab488e58..87e1e0e8d9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -268,6 +268,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding) + { + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } else { return transform_tensor_descriptor( @@ -329,6 +344,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding) + { + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } else { return transform_tensor_descriptor( diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index d367ad8df5..31ae71880a 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -189,6 +189,7 @@ struct vector_type } }; +int static err = 0; template struct vector_type { @@ -221,6 +222,10 @@ struct vector_type { return data_.d2x1_; } + else + { + return err; + } } template @@ -236,6 +241,10 @@ struct vector_type { return data_.d2x1_; } + else + { + return err; + } } }; @@ -278,6 +287,10 @@ struct vector_type { return data_.d4x1_; } + else + { + return err; + } } template @@ -298,6 +311,10 @@ struct vector_type { return data_.d4x1_; } + else + { + return err; + } } }; @@ -347,6 +364,10 @@ struct vector_type { return data_.d8x1_; } + else + { + return err; + } } template @@ -372,6 +393,10 @@ struct vector_type { return data_.d8x1_; } + else + { + return err; + } } }; @@ -428,6 +453,10 @@ struct vector_type { return data_.d16x1_; } + else + { + return err; + } } template @@ -458,6 +487,10 @@ struct vector_type { return data_.d16x1_; } + else + { + return err; + } } }; @@ -520,6 +553,10 @@ struct vector_type { return data_.d32x1_; } + else + { + return err; + } } template @@ -554,6 +591,10 @@ struct vector_type { return data_.d32x1_; } + else + { + return err; + } } }; @@ -623,6 +664,10 @@ struct vector_type { return data_.d64x1_; } + else + { + return err; + } } template @@ -662,6 +707,10 @@ struct vector_type { return data_.d64x1_; } + else + { + return err; + } } }; @@ -737,6 +786,10 @@ struct vector_type { return data_.d128x1_; } + else + { + return err; + } } template @@ -780,6 +833,10 @@ struct vector_type { return data_.d128x1_; } + else + { + return err; + } } }; @@ -861,6 +918,10 @@ struct vector_type { return data_.d256x1_; } + else + { + return err; + } } template @@ -908,6 +969,10 @@ struct vector_type { return data_.d256x1_; } + else + { + return err; + } } }; diff --git a/include/ck/utility/is_known_at_compile_time.hpp b/include/ck/utility/is_known_at_compile_time.hpp index 2cafc3e6f2..0916e4604e 100644 --- a/include/ck/utility/is_known_at_compile_time.hpp +++ b/include/ck/utility/is_known_at_compile_time.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -19,6 +19,12 @@ struct is_known_at_compile_time static constexpr bool value = false; }; +template <> +struct is_known_at_compile_time +{ + static constexpr bool value = false; +}; + template <> struct is_known_at_compile_time { diff --git a/include/ck/wrapper/layout.hpp b/include/ck/wrapper/layout.hpp index 1643eb7383..39b5c79c67 100644 --- a/include/ck/wrapper/layout.hpp +++ b/include/ck/wrapper/layout.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,22 +14,28 @@ namespace wrapper { * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t * (dynamic layout). It is possible to pass nested shapes * (e.g. ((4, 2), 2)), nested dimensions are merged. - * \tparam UnnestedDescriptorType Tensor descriptor for unnested shape dims. + * \tparam UnrolledDescriptorType Tensor descriptor for unnested shape dims. */ -template +template struct Layout { private: static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - // Generate default idxs tuple (idx with all merged nested shapes) + /** + * \brief Generate default indices tuple (idx with all merged nested shapes) + * + * \param shape Shape to align. + * \return Multi idx tuple with zeros. + */ template - __host__ __device__ constexpr static auto GenerateDefaultIdxsTuple(const Tuple&) + __host__ __device__ constexpr static auto + GenerateDefaultIdxsTuple([[maybe_unused]] const Tuple& shape) { return generate_tuple( [&](auto) { - if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime()) + if constexpr(!remove_cvref_t::IsKnownAtCompileTime()) { // runtime layout return index_t(0); @@ -43,11 +49,18 @@ struct Layout Number::Size()>{}); } - // Generate LowerDims in Compile-time for MergeTrasform using passed Type - // If element of Tuple is also tuple, then merge (generate sequence for merge) - // If tuple is element, then pass through (sequence with one element) + /** + * \brief Generate lower dims in compile-time for the Merge transform using + * provided type. If element of nested Tuple is also a tuple, then + * merge (generate sequence for merge). If tuple is element, then pass + * through (sequence with one element). + * + * \param shape Shape to align. + * \return LowerDims for MergeTrasform. + */ template - __host__ __device__ constexpr static auto GenerateLowerDim(const Tuple&) + __host__ __device__ constexpr static auto + GenerateLowerDim([[maybe_unused]] const Tuple& shape) { if constexpr(Idx::value == 0) { @@ -87,11 +100,17 @@ struct Layout } } - // Iterate over nested tuples in shape - // Unroll nested tuples to align Tuple to Tuple - // Example idx: (1, 1), 1, 1 - // Example shape: (2, (2, 2)), 2, (2, 2) - // Unrolled shape: 2, (2, 2), 2, (2, 2) + /** + * \brief Iterate over the nested tuples in the shape. + * Unroll nested tuples to align Tuple to Tuple + * Example idx: (1, 1), 1, 1 + * Example shape: (2, (2, 2)), 2, (2, 2) + * Unrolled shape: 2, (2, 2), 2, (2, 2) + * + * \param shape Layout shape. + * \param idx Idx to align. + * \return Algined shape. + */ template __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple& shape, const Tuple& idx) @@ -126,6 +145,13 @@ struct Layout } } + /** + * \brief Merge descriptor to 1D. + * + * \param shape Layout shape. + * \param desc Descriptor to merge. + * \return 1D descriptor. + */ template __host__ __device__ constexpr static auto MakeMerge1d(const Tuple& shape, const DescriptorToMerge& desc) @@ -137,18 +163,41 @@ struct Layout const auto lower_dims = make_tuple(MergeElemsSequence::Reverse()); const auto upper_dims = make_tuple(Sequence<0>{}); // Merge to 1d - return transform_tensor_descriptor( - desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims); + if constexpr(!remove_cvref_t::IsKnownAtCompileTime()) + { + return transform_tensor_descriptor( + desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims); + } + else + { + // If the descriptor is known at the compilation time, + // use `make_merge_transform_v1_carry_check` because it doesn't use + // memcpy. + return transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform_v1_carry_check(merge_elems)), + lower_dims, + upper_dims); + } } - // Merge nested shape dims when corresponding index is also nested. - // Input desc shape: 2, 2, 2, 2, 2, 2 - // Example idx: 1, 1, 1, 1 - // Example shape: 2, (2, 2), 2, (2, 2) - // Merged shape: 2, 4, 2, 4 + /** + * \brief Merge nested shape dims when corresponding index is also merged. + * Input desc shape: 2, 2, 2, 2, 2, 2 + * Example idx: 1, 1, 1, (1, 1) + * Example shape: 2, (2, 2), 2, (2, 2) + * Merged shape: 2, 4, 2, 2, 2 + * + * \param shape Layout shape. + * \param idxs Indexes to align descriptor. + * \param desc Descriptor to merge. + * \return Aligned descriptor to idx. + */ template - __host__ __device__ constexpr static auto CreateMergedDescriptor( - const Tuple& shape, const Tuple&, DescriptorToMerge& desc) + __host__ __device__ constexpr static auto + CreateMergedDescriptor(const Tuple& shape, + [[maybe_unused]] const Tuple& idxs, + DescriptorToMerge& desc) { const auto transforms = generate_tuple( [&](auto i) { @@ -160,7 +209,17 @@ struct Layout // If shape element is tuple and idx element is Number, then merge // Unroll and reverse tuple to traverse column-major const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i))); - return make_merge_transform(merge_elems); + if constexpr(!remove_cvref_t::IsKnownAtCompileTime()) + { + return make_merge_transform(merge_elems); + } + else + { + // If the descriptor is known at the compilation time, + // use `make_merge_transform_v1_carry_check` because + // it doesn't use memcpy. + return make_merge_transform_v1_carry_check(merge_elems); + } } else { @@ -185,14 +244,23 @@ struct Layout } using Descriptor1dType = - remove_cvref_t; + remove_cvref_t; using DefaultIdxsTupleType = remove_cvref_t; + public: + /** + * \brief Transform descriptor to align to passed indexes. + * + * \param shape Layout shape. + * \param idxs Indexes to align descriptor. + * \param naive_descriptor Descriptor to merge. + * \return Aligned descriptor to idx. + */ template __host__ __device__ constexpr static auto TransformDesc(const Tuple& shape, - const Tuple& idx, - const UnnestedDescriptorType& naive_descriptor) + const Tuple& idxs, + const UnrolledDescriptorType& naive_descriptor) { if constexpr(Tuple::Size() == I1) { @@ -208,19 +276,18 @@ struct Layout static_assert(Tuple::Size() == Tuple::Size(), "Idx rank and Shape rank must be the same (except 1d)."); // Unroll while IdxDims is nested - const auto aligned_shape = AlignShapeToIdx(shape, idx); + const auto aligned_shape = AlignShapeToIdx(shape, idxs); // Transform correct form of shape - return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), naive_descriptor); + return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idxs), naive_descriptor); } } using MergedNestsDescriptorType = remove_cvref_t; + Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))>; - public: __host__ __device__ constexpr auto GetElementSpaceSize() const { - return unnested_descriptor_.GetElementSpaceSize(); + return unrolled_descriptor_.GetElementSpaceSize(); } __host__ __device__ Layout() = delete; @@ -232,16 +299,15 @@ struct Layout * \param unnested_descriptor Descriptor */ __host__ __device__ constexpr Layout(const Shape& shape, - const UnnestedDescriptorType& unnested_descriptor) - : shape_(shape) + const UnrolledDescriptorType& unnested_descriptor) + : unrolled_descriptor_(unnested_descriptor), shape_(shape) { // Construct if runtime mode - if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime()) + if constexpr(!remove_cvref_t::IsKnownAtCompileTime()) { - unnested_descriptor_ = unnested_descriptor; - descriptor_1d_ = MakeMerge1d(shape_, unnested_descriptor_); + descriptor_1d_ = MakeMerge1d(shape_, unrolled_descriptor_); merged_nests_descriptor_ = - TransformDesc(shape_, DefaultIdxsTupleType{}, unnested_descriptor_); + TransformDesc(shape_, DefaultIdxsTupleType{}, unrolled_descriptor_); } } @@ -254,9 +320,9 @@ struct Layout template __host__ __device__ constexpr index_t operator()() const { - static_assert(UnnestedDescriptorType::IsKnownAtCompileTime(), + static_assert(remove_cvref_t::IsKnownAtCompileTime(), "Compiletime operator used on runtime layout."); - using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnnestedDescriptorType{})); + using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnrolledDescriptorType{})); using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{})); return TransformedDesc{}.CalculateOffset(UnrolledIdx{}); } @@ -283,7 +349,7 @@ struct Layout else { // Custom index, need to transform descriptor - const auto transformed_desc = TransformDesc(shape_, Idx, unnested_descriptor_); + const auto transformed_desc = TransformDesc(shape_, Idx, unrolled_descriptor_); return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx)); } } @@ -350,29 +416,55 @@ struct Layout } /** - * \brief Get default descriptor (with the same size as Shape) + * \brief Get descriptor with all nested dimensions merged. + * Example, shape: ((2, 2), 2) + * Descriptor lengths: (4, 2) * - * \return Default descriptor. + * \note The size of merged descriptor is the same as Layout's shape. + * + * \return Merged nests descriptor. */ - __host__ __device__ constexpr const MergedNestsDescriptorType& GetDefaultDescriptor() const + __host__ __device__ constexpr const MergedNestsDescriptorType& + GetMergedNestingDescriptor() const { return merged_nests_descriptor_; } /** - * \brief Get unnested descriptor (with unrolled dims) + * \brief Get descriptor with all dimensions are merged (1D). + * Example, shape: ((2, 2), 2) + * Descriptor lengths: (8) * - * \return Flatten descriptor. + * \return 1D descriptor. */ - __host__ __device__ constexpr const UnnestedDescriptorType& GetUnnestedDescriptor() const + __host__ __device__ constexpr const Descriptor1dType& Get1DDescriptor() const { - return unnested_descriptor_; + return descriptor_1d_; + } + + /** + * \brief Get unnested descriptor (with unrolled dims) + * Example, shape: ((2, 2), 2) + * Descriptor lengths: (2, 2, 2) + * + * \return Flattened descriptor. + */ + __host__ __device__ constexpr const UnrolledDescriptorType& GetUnrolledDescriptor() const + { + return unrolled_descriptor_; } private: - UnnestedDescriptorType unnested_descriptor_; + // All dimensions are unrolled + UnrolledDescriptorType unrolled_descriptor_; + // 1D descriptor Descriptor1dType descriptor_1d_; + // All nesting are merged MergedNestsDescriptorType merged_nests_descriptor_; + // Example, shape: ((2, 2), 2) + // UnrolledDescriptorType lengths: (2, 2, 2) + // Descriptor1dType lengths: (8) + // MergedNestsDescriptorType lengths: (4, 2) const Shape shape_; }; diff --git a/include/ck/wrapper/operations/copy.hpp b/include/ck/wrapper/operations/copy.hpp index aec80f9ca7..7b00fe5500 100644 --- a/include/ck/wrapper/operations/copy.hpp +++ b/include/ck/wrapper/operations/copy.hpp @@ -1,16 +1,21 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "../utils/tensor_utils.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + namespace ck { namespace wrapper { /** - * \brief Perform generic copy between two tensors. Tensors must have the - * same size. + * \brief Perform generic copy between two tensors partitions (threadwise copy). + * Tensors must have the same size. * * \param src_tensor Source tensor. * \param dst_tensor Destination tensor. @@ -37,5 +42,134 @@ __host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& ds } } +/** + * \brief Perform optimized copy between two tensors partitions (threadwise copy). + * Tensors must have the same size. + * + * \tparam DimAccessOrderTuple Tuple with dimension access order. + * \tparam VectorDim Dimension for vectorized read and write. + * \tparam ScalarPerVector Number of scalar per vectorized read and write. + * \param src_tensor Source tensor. + * \param dst_tensor Destination tensor. + */ +template +__device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) +{ + static_assert(is_detected::value); + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const auto& in_grid_desc = layout(src_tensor).GetUnrolledDescriptor(); + const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor(); + + using SrcShapeType = remove_cvref_t; + constexpr index_t num_dims = SrcShapeType::Size(); + + constexpr auto thread_slice_lengths = + generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number{}); + constexpr auto dim_access_order = generate_sequence_v2( + [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number{}); + + if constexpr(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer) + { + // Perform a copy between DynamicBuffers + auto transfer = ThreadwiseTensorSliceTransfer_v7< + Tuple, + Tuple, + decltype(tie(in_grid_desc)), + decltype(tie(out_grid_desc)), + tensor_operation::element_wise::PassThrough, + Sequence(InMemoryDataOperationEnum::Set)>, + decltype(thread_slice_lengths), + decltype(dim_access_order), + VectorDim, + ScalarPerVector, + Sequence, + Sequence>{in_grid_desc, + make_tuple(src_tensor.GetMultiIdxOffsets()), + out_grid_desc, + make_tuple(dst_tensor.GetMultiIdxOffsets()), + tensor_operation::element_wise::PassThrough{}}; + + transfer.Run(tie(in_grid_desc), + tie(src_tensor.GetBuffer()), + tie(out_grid_desc), + tie(dst_tensor.GetBuffer())); + } + else if constexpr(!SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer) + { + // Perform copy from StaticBuffer to DynamicBuffer + const auto src_slice_origin_idxs = + generate_tuple([&](auto) { return I0; }, Number{}); + + auto transfer = + ThreadwiseTensorSliceTransfer_v1r3, + remove_cvref_t, + tensor_operation::element_wise::PassThrough, + decltype(thread_slice_lengths), + decltype(dim_access_order), + VectorDim, + ScalarPerVector, + InMemoryDataOperationEnum::Set, + I1, + true>{out_grid_desc, + dst_tensor.GetMultiIdxOffsets(), + tensor_operation::element_wise::PassThrough{}}; + + transfer.Run(in_grid_desc, + src_slice_origin_idxs, + src_tensor.GetBuffer(), + out_grid_desc, + dst_tensor.GetBuffer()); + } + else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer) + { + // Perform copy from DynamicBuffer to StaticBuffer + const auto src_dst_slice_origin = + generate_tuple([&](auto) { return I0; }, Number{}); + constexpr auto src_vector_tensor_lengths = generate_sequence_v2( + [&](auto I) { + if constexpr(I == VectorDim) + { + return Number{}; + } + else + { + return I1; + } + }, + Number{}); + + auto transfer = + ThreadwiseTensorSliceTransfer_v4r1, + remove_cvref_t, + decltype(thread_slice_lengths), + decltype(dim_access_order), + decltype(src_vector_tensor_lengths), + decltype(dim_access_order)>{ + src_tensor.GetMultiIdxOffsets()}; + + transfer.Run(in_grid_desc, + src_dst_slice_origin, + src_tensor.GetBuffer(), + out_grid_desc, + src_dst_slice_origin, + dst_tensor.GetBuffer()); + } + else + { + // Perform copy between StaticBuffers + copy(src_tensor, dst_tensor); + } +} + } // namespace wrapper } // namespace ck diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index a363641373..57d79c5940 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -10,189 +10,205 @@ namespace ck { namespace wrapper { +namespace detail { +namespace { +/** + * \brief Check if Tuple contains Slice object + * + * \return True if tuple contains Slice object. + */ +template +__host__ __device__ constexpr bool HasSlice(T&&) +{ + return is_detected::value; +} +template +__host__ __device__ constexpr bool HasSlice(Tuple&&) +{ + return (HasSlice(Ts{}) || ...); +} + +/** + * \brief Calculate new shape after slice from parent shape. + * + * \param idxs Tuple of indexes defining slice ranges. + * \param shape Shape which will be sliced. + * \return New tensor shape. + */ +template +__host__ __device__ constexpr auto GetSlicedShape(const Tuple& idxs, + const SlicedShape& shape) +{ + // Pack each value in tuple to remove empty tuples after generation + auto new_shape = generate_tuple( + [&](auto i) { + constexpr auto num_i = Number{}; + if constexpr(is_detected>>::value) + { + if constexpr(!detail::HasSlice(tuple_element_t>{})) + { + // if tuple does not have any slice then we can remove dimension + return Tuple<>{}; + } + else + { + // if tuple then recurrence + return make_tuple(GetSlicedShape(idxs.At(num_i), shape.At(num_i))); + } + } + else if constexpr(is_detected>>::value) + { + // calculate new dimension + const auto& dim = size(shape.At(num_i)); + const auto val = idxs.At(num_i).range(dim); + return make_tuple(val); + } + else + { + // remove dimension for just value + return Tuple<>{}; + } + }, + Number::Size()>{}); + // Remove empty tuples (deleted elements) and return + return UnrollNestedTuple<0, 1>(new_shape); +} + +/** + * \brief Generate Freeze for each of nested shape. + * + * \param idx Tuple of start indices for slice. + * \param shape Shape which will be freezed. + * \return Generated freeze transforms. + */ +template +__host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, const Shape& shape) +{ + const auto unrolled_shape = UnrollNestedTuple(shape); + return generate_tuple( + [&](auto i) { + // dimension offset from idx + const auto dim = unrolled_shape.At(Number{}); + const auto dim_idx = idx % dim; + idx /= dim; + return make_freeze_transform(dim_idx); + }, + Number{}); +} + +/** + * \brief Generate transforms for slice tensor. + * + * \param idx Tuple of start indices for slice. + * \param shape Shape which will be sliced. + * \return Generated transforms. + */ +template +__host__ __device__ constexpr auto GenerateSliceTransforms(const Tuple& idx, + const Shape& shape) +{ + // Pack each value in tuple to remove empty tuples after generation + auto transforms = generate_tuple( + [&](auto i) { + constexpr auto num_i = Number{}; + if constexpr(is_detected>>::value) + { + return GenerateSliceTransforms(idx.At(num_i), shape.At(num_i)); + } + else if constexpr(is_detected>>::value) + { + + const auto from = idx.At(num_i).from_; + const auto dim = size(shape); + const auto range = idx.At(num_i).range(dim); + return make_slice_transform(range, from, from + range); + } + else + { + // remove dimension for just value + return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i)); + } + }, + Number::Size()>{}); + // Remove empty tuples (deleted elements) and return + return UnrollNestedTuple(transforms); +} + +template +__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze&) +{ + // There is no output for Freeze transform + return Sequence<>{}; +} + +template +__host__ __device__ constexpr auto GetSequenceVal(const ck::Slice&) +{ + return Sequence{}; +} + +template +__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&) +{ + return Tuple<>{}; +} + +template +__host__ __device__ constexpr auto GenerateUpperDims(const Tuple& transforms) +{ + constexpr auto num_transforms = Tuple::Size(); + // Deduce Sequence element for specific transform + const auto current_elem = GetSequenceVal(transforms.At(Number<0>{})); + if constexpr(is_same_v>) + { + const auto next_tuple = GenerateUpperDims(TupleSlice<1, num_transforms>(transforms)); + return concat_tuple(make_tuple(current_elem), next_tuple); + } + else + { + // Increase i if current_elem is Slice transform + const auto next_tuple = GenerateUpperDims(TupleSlice<1, num_transforms>(transforms)); + return concat_tuple(make_tuple(current_elem), next_tuple); + } +} + +template +__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple& idx, + const Shape& shape, + const FlattenDescriptor& flatten_desc) +{ + constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size(); + + const auto transforms = GenerateSliceTransforms(idx, shape); + using TransformsTupleType = decltype(transforms); + + const auto lower_dims = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){}; + return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims); +} +} // namespace +} // namespace detail + /** * \brief Tensor wrapper that performs static and dynamic buffer logic. + * The tensor is based on a descriptor stored in the Layout. Additionally, + * tensor can be sliced or shifted using multi-index offset. * * \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR). * \tparam ElementType Element data type. * \tparam Shape Tensor shape (layout component). - * \tparam UnnestedDescriptorType Unnested descriptor (layout component). - * \tparam NumVectors Number of vectors (only for VGPR, SGPR). - * \tparam ScalarPerVector Scalars per vector (only for VGPR, SGPR). + * \tparam UnrolledDescriptorType Flatten descriptor (layout component). */ template + typename UnrolledDescriptorType> struct Tensor { - private: - // Check if Tuple contains Slice object - template - __host__ __device__ constexpr static bool IsSlicing(T&&) - { - return is_detected::value; - } - template - __host__ __device__ constexpr static bool IsSlicing(Tuple&&) - { - return (IsSlicing(Ts{}) || ...); - } - - // Calculate new tensor shape after slice - template - __host__ __device__ constexpr auto GetShapeFromSlicedTensor(const Tuple& idx, - const ShapeTmpType& shape) const - { - // Pack each value in tuple to remove empty tuples after generation - auto new_shape = generate_tuple( - [&](auto i) { - constexpr auto num_i = Number{}; - if constexpr(is_detected>>::value) - { - if constexpr(!IsSlicing(tuple_element_t>{})) - { - // if tuple does not have any slice then we can remove dimension - return Tuple<>{}; - } - else - { - // if tuple then recurrence - return make_tuple(GetShapeFromSlicedTensor(idx.At(num_i), shape.At(num_i))); - } - } - else if constexpr(is_detected>>::value) - { - // calculate new dimension - const auto& dim = size(shape.At(num_i)); - const auto val = idx.At(num_i).range(dim); - return make_tuple(val); - } - else - { - // remove dimension for just value - return Tuple<>{}; - } - }, - Number::Size()>{}); - // Remove empty tuples (deleted elements) and return - return UnrollNestedTuple<0, 1>(new_shape); - } - - // Generate Freeze for each of nested shape - template - __host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, - const ShapeTmpType& shape) const - { - const auto unrolled_shape = UnrollNestedTuple(shape); - return generate_tuple( - [&](auto i) { - // dimension offset from idx - const auto dim = unrolled_shape.At(Number{}); - const auto dim_idx = idx % dim; - idx /= dim; - return make_freeze_transform(dim_idx); - }, - Number{}); - } - - template - __host__ __device__ constexpr auto - GetTransformsFromSlicedTensor(const Tuple& idx, const ShapeTmpType& shape) const - { - // Pack each value in tuple to remove empty tuples after generation - auto transforms = generate_tuple( - [&](auto i) { - constexpr auto num_i = Number{}; - if constexpr(is_detected>>::value) - { - return GetTransformsFromSlicedTensor(idx.At(num_i), shape.At(num_i)); - } - else if constexpr(is_detected>>::value) - { - - const auto from = idx.At(num_i).from_; - const auto dim = shape.At(num_i); - const auto range = idx.At(num_i).range(dim); - return make_slice_transform(range, from, from + range); - } - else - { - // remove dimension for just value - return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i)); - } - }, - Number::Size()>{}); - // Remove empty tuples (deleted elements) and return - return UnrollNestedTuple(transforms); - } - - // There is no output for Freeze transform - template - __host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze&) const - { - return Sequence<>{}; - } - - template - __host__ __device__ constexpr auto - GetSequenceVal(const ck::Slice&) const - { - return Sequence{}; - } - - template - __host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&) const - { - return Tuple<>{}; - } - - template - __host__ __device__ constexpr auto - GenerateUpperDims(const Tuple& transforms) const - { - constexpr auto num_transforms = Tuple::Size(); - // Deduce Sequence element for specific transform - const auto currect_elem = GetSequenceVal(transforms.At(Number<0>{})); - if constexpr(is_same_v>) - { - const auto next_tuple = GenerateUpperDims(TupleSlice<1, num_transforms>(transforms)); - return concat_tuple(make_tuple(currect_elem), next_tuple); - } - else - { - // Increase i if current_elem is Slice transform - const auto next_tuple = - GenerateUpperDims(TupleSlice<1, num_transforms>(transforms)); - return concat_tuple(make_tuple(currect_elem), next_tuple); - } - } - - template - __host__ __device__ constexpr auto - GetDescriptorFromSlicedTensor(const Tuple& idx, - const ShapeTmpType& shape, - const FlattenDescriptor& flatten_desc) const - { - constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size(); - - const auto transforms = GetTransformsFromSlicedTensor(idx, shape); - using TransformsTupleType = decltype(transforms); - - const auto lower_dims = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); - const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){}; - return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims); - } - public: - using ElementSpaceSize = decltype(Layout{ - Shape{}, UnnestedDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer + using ElementSpaceSize = decltype(Layout{ + Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer using TensorElementType = ElementType; // DataType static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace; @@ -200,134 +216,207 @@ struct Tensor BufferAddressSpace == MemoryTypeEnum ::Vgpr); __host__ __device__ Tensor() = delete; - __host__ __device__ Tensor(ElementType* pointer, - const Layout& layout) + __host__ __device__ constexpr Tensor(ElementType* pointer, + const Layout& layout) : layout_(layout), - buffer_(make_dynamic_buffer(pointer, layout.GetElementSpaceSize())) + buffer_(make_dynamic_buffer(pointer, layout.GetElementSpaceSize())), + multi_idx_offset_(make_zero_multi_index()), + base_offset_(0) { + static_assert(IsDynamicBuffer, "Wrong BufferAddressSpace for register."); } - __host__ __device__ Tensor(const Layout& layout) - : layout_(layout) + __host__ __device__ constexpr Tensor(const Layout& layout) + : layout_(layout), + multi_idx_offset_(make_zero_multi_index()), + base_offset_(0) { static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register."); } - __host__ __device__ constexpr const Layout& GetLayout() const + __host__ __device__ constexpr const Layout& GetLayout() const { return layout_; } - // Getter for new sliced tensor - template {}), bool> = false> - __host__ __device__ auto operator[](const Tuple& idx) const + /** + * \brief Get the new sliced tensor. + * + * \param idx Tuple of indices: slice(from,to) or scalar. + * \return Sliced tensor. + */ + template {}), bool> = false> + __host__ __device__ auto operator[](const Tuple& idx) { static_assert(IsDynamicBuffer, "Register slice is not supported"); const auto& shape = layout_.GetShape(); - auto new_shape = GetShapeFromSlicedTensor(idx, shape); + auto new_shape = detail::GetSlicedShape(idx, shape); - const auto& flatten_desc = layout_.GetUnnestedDescriptor(); - auto new_desc = GetDescriptorFromSlicedTensor(idx, shape, flatten_desc); + const auto& flatten_desc = layout_.GetUnrolledDescriptor(); + auto new_desc = detail::GenerateSlicedDescriptor(idx, shape, flatten_desc); const auto new_layout = Layout(new_shape, new_desc); + // Update embed offset + base_offset_ -= new_layout(make_tuple(Number<0>{})); return make_tensor(buffer_.p_data_, new_layout); } - template {}), bool> = false> - __host__ __device__ auto operator()(const Tuple& idx) const + template {}), bool> = false> + __host__ __device__ auto operator()(const Tuple& idx) { return this->operator[](idx); } - template {}), bool> = false> - __host__ __device__ auto operator()(Idxs... idxs) const + template {}), bool> = false> + __host__ __device__ auto operator()(Idxs... idxs) { return this->operator[](make_tuple(idxs...)); } - // Getter for the const value - template {}), bool> = false> + /** + * \brief Getter of the tensor's const value reference. + * + * \param idx Tuple of indices. + * \return Requested value. + */ + template {}), bool> = false> __host__ __device__ const ElementType& operator[](const Tuple& idx) const { if constexpr(IsDynamicBuffer) { - const index_t offset = layout_(idx); + const index_t offset = layout_(idx) + base_offset_; return buffer_[offset]; } else { - constexpr index_t offset = Layout{ + constexpr index_t index_offset = Layout{ Shape{}, - UnnestedDescriptorType{}}.template operator()>(); - return buffer_[Number{}]; + UnrolledDescriptorType{}}.template operator()>(); + // Calculate and apply base offset in compile-time + constexpr index_t base_offset = Layout{ + Shape{}, + UnrolledDescriptorType{}}.template operator()>(); + return buffer_[Number{}]; } } - template {}), bool> = false> + template {}), bool> = false> __host__ __device__ const ElementType& operator()(const Tuple& idx) const { return this->operator[](idx); } - template {}), bool> = false> + template {}), bool> = false> __host__ __device__ const ElementType& operator()(Idxs... idxs) const { return this->operator[](make_tuple(idxs...)); } - // Getter for the value reference - template {}), bool> = false> + /** + * \brief Getter of tensor value reference. + * + * \param idx Tuple of indices. + * \return Requested value. + */ + template {}), bool> = false> __host__ __device__ ElementType& operator[](const Tuple& idx) { if constexpr(IsDynamicBuffer) { - const index_t offset = layout_(idx); + const index_t offset = layout_(idx) + base_offset_; return buffer_(offset); } else { - constexpr index_t offset = Layout{ + constexpr index_t index_offset = Layout{ Shape{}, - UnnestedDescriptorType{}}.template operator()>(); - return buffer_(Number{}); + UnrolledDescriptorType{}}.template operator()>(); + // Apply embed offset (calculate in compiletime) + constexpr index_t base_offset = Layout{ + Shape{}, + UnrolledDescriptorType{}}.template operator()>(); + return buffer_(Number{}); } } - template {}), bool> = false> + template {}), bool> = false> __host__ __device__ ElementType& operator()(const Tuple& idx) { return this->operator[](idx); } - template {}), bool> = false> + template {}), bool> = false> __host__ __device__ ElementType& operator()(Idxs... idxs) { return this->operator[](make_tuple(idxs...)); } - __host__ __device__ constexpr auto GetDefaultDescriptor() + /** + * \brief Get descriptor with all nested dimensions merged. + * + * \return Merged nests descriptor. + */ + __host__ __device__ constexpr auto GetMergedNestingDescriptor() { - return layout_.GetDefaultDescriptor(); + return layout_.GetMergedNestingDescriptor(); } + /** + * \brief Get pointer to the data. + * + * \return Pointer. + */ __host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; } + __host__ __device__ constexpr auto& GetBuffer() { return buffer_; } + __host__ __device__ constexpr auto& GetBuffer() const { return buffer_; } + + /** + * \brief Get multi index offset to the data. + * + * \return Multi index offset. + */ + __host__ __device__ constexpr auto& GetMultiIdxOffsets() const { return multi_idx_offset_; } + + /** + * \brief Apply multi index offset on the tensor. + * + * \param multi_idx_offset Multi index offset. + */ + template + __host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset) + { + multi_idx_offset_ = multi_idx_offset; + base_offset_ += layout_(multi_idx_offset); + } + private: using DynamicBufferType = DynamicBuffer; - using StaticBufferType = - StaticBufferTupleOfVector; + using StaticBufferType = StaticBuffer; // If register use static buffer, else use dynamic buffer using Buffer = std::conditional_t; - const Layout layout_; + const Layout layout_; Buffer buffer_; + // We use multi_idx_offset_ to enable the creation of a descriptor in + // compile time for partitions or tiles if tile shape and thread layout + // is known at compile time (We can use the same descriptor for each + // thread). Additionally, the copy between the static and dynamic buffer + // requires a descriptor known at compile time, so we can shift data using + // such multi_idx_offset_. + MultiIndex multi_idx_offset_; + // Base offset and multi index offset are corresponding to exactly the + // same element in tensor ( and in physical memory ). Multi index offset + // is multi dimensional index. However base offset is calculated using + // tensor descriptor (thus all it's transforms) and is linear (1D). + // We store base_offset_ to avoid multiple recalculations. + index_t base_offset_; }; } // namespace wrapper diff --git a/include/ck/wrapper/utils/layout_utils.hpp b/include/ck/wrapper/utils/layout_utils.hpp index f4ba0a969f..d04bd5078b 100644 --- a/include/ck/wrapper/utils/layout_utils.hpp +++ b/include/ck/wrapper/utils/layout_utils.hpp @@ -22,14 +22,19 @@ namespace wrapper { // Disable from doxygen docs generation /// @cond // forward declaration -template +template struct Layout; template using is_tuple = decltype(std::declval().IsTuple()); namespace { -// Generate packed (column-major) strides if not passed +/** + * \brief Generate packed (column-major) strides if not passed + * + * \param shape Tensor shape. + * \return Generated column-major strides. + */ template __host__ __device__ constexpr static auto GenerateColumnMajorPackedStrides(const Tuple& shape) @@ -50,9 +55,16 @@ GenerateColumnMajorPackedStrides(const Tuple& shape) Number{}); } +/** + * \brief Create naive tensor descriptor from nested shape. + * + * \param shape Tensor shape. + * \param strides Tensor strides. + * \return Unrolled descriptor + */ template -__host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shape, - const LayoutStrides& strides) +__host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& shape, + const LayoutStrides& strides) { const auto unrolled_shape = UnrollNestedTuple(shape); if constexpr(is_same_v>) @@ -86,8 +98,8 @@ __host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shap template __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides) { - using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Strides{})); - return Layout(shape, MakeFlattenDescriptor(shape, strides)); + using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{})); + return Layout(shape, MakeUnrolledDescriptor(shape, strides)); } /** @@ -100,15 +112,19 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides template __host__ __device__ constexpr auto make_layout(const Shape& shape) { - using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Tuple<>{})); - return Layout(shape, MakeFlattenDescriptor(shape, Tuple<>{})); + using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{})); + return Layout(shape, MakeUnrolledDescriptor(shape, Tuple<>{})); } // Layout helpers // get -// Get dim (could be returned from get with empty Idxs) + /** * \private + * \brief Get dim. + * + * \param dim Dimension. + * \return Returned the same dimension. */ template __host__ __device__ T constexpr get(const T& dim) @@ -178,7 +194,7 @@ __host__ __device__ constexpr auto get(const Layout& layout) }, Number{}); - const auto& flatten_desc = layout.GetUnnestedDescriptor(); + const auto& flatten_desc = layout.GetUnrolledDescriptor(); auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims); return Layout(new_shape, new_desc); } @@ -197,9 +213,12 @@ __host__ __device__ constexpr auto get(const T& elem) } // size -// Get dim size (could be returned from get function) /** * \private + * \brief Get size. + * + * \param dim Size. + * \return Returned the same size. */ template __host__ __device__ T constexpr size(const T& dim) @@ -214,8 +233,8 @@ __host__ __device__ T constexpr size(const T& dim) * \param layout Layout to get Shape of. * \return Requsted length. */ -template -__host__ __device__ constexpr auto size(const Layout& layout) +template +__host__ __device__ constexpr auto size(const Layout& layout) { return layout.template GetLength(); } @@ -240,8 +259,8 @@ __host__ __device__ constexpr auto size(const Tuple& shape) * \param layout Layout to calculate shape size. * \return Requsted size. */ -template -__host__ __device__ constexpr auto size(const Layout& layout) +template +__host__ __device__ constexpr auto size(const Layout& layout) { return layout.GetLengths(); } @@ -280,9 +299,9 @@ __host__ __device__ constexpr auto size(const T& elem) * \param layout Layout to calculate rank. * \return Requsted rank. */ -template +template __host__ __device__ constexpr auto -rank([[maybe_unused]] const Layout& layout) +rank([[maybe_unused]] const Layout& layout) { return Shape::Size(); } @@ -302,17 +321,25 @@ __host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple& t /** * \private + * \brief Rank for scalar + * + * \param dim Dimension scalar. + * \return Returned 1. */ template -__host__ __device__ constexpr index_t rank(const Number&) +__host__ __device__ constexpr index_t rank([[maybe_unused]] const Number& dim) { return 1; } /** * \private + * \brief Rank for scalar + * + * \param dim Dimension scalar. + * \return Returned 1. */ -__host__ __device__ constexpr index_t rank(const index_t&) { return 1; } +__host__ __device__ constexpr index_t rank([[maybe_unused]] const index_t& dim) { return 1; } /** * \brief Hierarchical rank. @@ -334,8 +361,8 @@ __host__ __device__ constexpr auto rank(const T& elem) * \param layout Layout to calculate depth. * \return Requsted depth. */ -template -__host__ __device__ constexpr auto depth(const Layout& layout) +template +__host__ __device__ constexpr auto depth(const Layout& layout) { const auto& shape = layout.GetShape(); return TupleDepth(shape); @@ -355,17 +382,25 @@ __host__ __device__ constexpr auto depth(const Tuple& tuple) /** * \private + * \brief Depth for scalar + * + * \param dim Scalar. + * \return Returned 0. */ template -__host__ __device__ constexpr index_t depth(const Number&) +__host__ __device__ constexpr index_t depth([[maybe_unused]] const Number& dim) { return 0; } /** * \private + * \brief Depth for scalar + * + * \param dim Scalar. + * \return Returned 0. */ -__host__ __device__ constexpr index_t depth(const index_t&) { return 0; } +__host__ __device__ constexpr index_t depth([[maybe_unused]] const index_t& dim) { return 0; } /** * \brief Hierarchical depth. diff --git a/include/ck/wrapper/utils/tensor_partition.hpp b/include/ck/wrapper/utils/tensor_partition.hpp index a0634f6b38..6aae5a92fe 100644 --- a/include/ck/wrapper/utils/tensor_partition.hpp +++ b/include/ck/wrapper/utils/tensor_partition.hpp @@ -6,12 +6,22 @@ #include "tensor_utils.hpp" #include "layout_utils.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" + namespace ck { namespace wrapper { namespace { -// Calculate shape for partition based on number of threads per each dim and -// previous shape + +/** + * \brief Calculate shape for partition based on number of threads per each dim and + * previous shape + * + * \param shape Base tensor shape. + * \param thread_lengths Tuple of thread lengths. + * \return Partition shape. + */ template __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple& shape, const Tuple& thread_lengths) @@ -20,265 +30,165 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple{}; - if constexpr(is_detected>>::value) - { - // if tuple then recurrence - return CalculateLocalPartitionShape(shape.At(num_i), thread_lengths.At(num_i)); - } - else - { - const auto slice_len = shape.At(num_i) / thread_lengths.At(num_i); - return slice_len; - } - }, - Number::Size()>{}); -} - -// Calculate shape for partition based on number of threads per each dim, -// previous strides and steps -template -__host__ __device__ constexpr auto -CalculateLocalPartitionDescriptor(const Tuple& shape, - const Tuple& thread_lengths, - const Tuple& steps, - const FlattenDescType& flatten_desc) -{ - - static_assert(Tuple::Size() == Tuple::Size(), "Wrong thread_lengths shape."); - const auto unrolled_thread_lengths = UnrollNestedTuple(thread_lengths); - const auto unrolled_shape = UnrollNestedTuple(shape); - constexpr auto dims = decltype(unrolled_thread_lengths)::Size(); - - using UnrolledStepsType = decltype(UnrollNestedTuple(steps)); - - using I1 = Number<1>; - - const auto transforms = generate_tuple( - [&](auto i) { - constexpr auto num_i = Number{}; - if constexpr(is_same_v, Tuple<>>) - { - // By default raked partition - const auto partition_stride = unrolled_thread_lengths.At(num_i); - return make_embed_transform(make_tuple(unrolled_shape.At(num_i)), - make_tuple(partition_stride)); - } - else if constexpr(!is_same_v, index_t>) - { - // Compiletime partition - if constexpr(is_same_v, I1>) - { - // raked - const auto partition_stride = unrolled_thread_lengths.At(num_i); - return make_embed_transform(make_tuple(unrolled_shape.At(num_i)), - make_tuple(partition_stride)); - } - else - { - // packed - return make_embed_transform(make_tuple(unrolled_shape.At(num_i)), - make_tuple(I1{})); - } - } - else - { - // Runtime partition - if(steps.At(num_i) == 1) - { - // raked - const auto partition_stride = unrolled_thread_lengths.At(num_i); - return make_embed_transform(make_tuple(unrolled_shape.At(num_i)), - make_tuple(partition_stride)); - } - else - { - // packed - return make_embed_transform(make_tuple(unrolled_shape.At(num_i)), - make_tuple(I1{})); - } - } - }, - Number{}); - - const auto lower_dims = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); - const auto upper_dims = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); - return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims); -} - -template -__host__ __device__ constexpr auto CalculateLayoutOffsetIdxImpl(const Tuple& thread_lengths, - const Tuple& steps, - index_t& thread_id) -{ - return generate_tuple( - [&](auto i) { - constexpr auto num_i = Number{}; - if constexpr(is_detected>>::value) - { - // if tuple then recurrence - if constexpr(is_same_v, Tuple<>>) - { - return CalculateLayoutOffsetIdxImpl( - thread_lengths.At(num_i), Tuple<>{}, thread_id); - } - else - { - return CalculateLayoutOffsetIdxImpl( - thread_lengths.At(num_i), steps.At(num_i), thread_id); - } - } - else - { - // Update thread_id after each dim - const auto dim_thread_id = thread_id % thread_lengths.At(num_i); - thread_id /= thread_lengths.At(num_i); - if constexpr(is_same_v, Tuple<>>) - { - return dim_thread_id; - } - else - { - // Apply step - return steps.At(num_i) * dim_thread_id; - } - } + const auto slice_len = size(shape) / thread_lengths.At(num_i); + return slice_len; }, Number::Size()>{}); } -// Convert integer thread_idx to tuple index with steps applied -template -__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple& thread_lengths, - const Tuple& steps, - const index_t thread_id) +/** + * \brief Calculate total number of blocks. + * + * \param shape Base tensor shape. + * \param tile_shape Tile shape. + * \return Tuple with blocks number. + */ +template +__host__ __device__ constexpr auto CalculateGridSize(const Tuple& shape, + const Tuple& tile_shape) { - // Create tmp thread_id copy for CalculateLayoutOffsetIdxImpl updates - index_t thread_id_copy = thread_id; - return CalculateLayoutOffsetIdxImpl(thread_lengths, steps, thread_id_copy); + static_assert(Tuple::Size() == Tuple::Size(), "Wrong thread_lengths shape."); + return generate_tuple([&](auto i) { return size(shape) / size(tile_shape); }, + Number::Size()>{}); } -// Apply steps to index represented as tuple -template -__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple& steps, - const Tuple& block_idxs) +/** + * \brief Calculate scaled offset for new partition/tile. + * + * \param thread_idxs Thread 1d id. + * \param partition_lengths_seq Sequence of partition shape. + * \param old_offset_idxs Multi index offset from base tensor to shift values. + * \return Partition shape. + */ +template +__host__ __device__ constexpr auto +CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs, + const PartitionLengthsSeq& partition_lengths_seq, + const OldOffsetIdxs& old_offset_idxs) { - return generate_tuple( - [&](auto i) { - constexpr auto num_i = Number{}; - if constexpr(is_detected>>::value) - { - // if tuple then recurrence - if constexpr(is_same_v, Tuple<>>) - { - return CalculateLayoutOffsetIdx(Tuple<>{}, block_idxs.At(num_i)); - } - else - { - return CalculateLayoutOffsetIdx(steps.At(num_i), block_idxs.At(num_i)); - } - } - else - { - if constexpr(is_same_v, Tuple<>>) - { - return block_idxs.At(num_i); - } - else - { - // apply step - return steps.At(num_i) * block_idxs.At(num_i); - } - } - }, - Number::Size()>{}); + return thread_idxs * partition_lengths_seq + old_offset_idxs; } -// User passes only shape per block to the make_local_tile function. This function calculates -// block layout based on the shape. -template -__host__ __device__ constexpr auto CalculateBlockLengths(const Tuple& shape, - const Tuple& tile_shape) -{ - return generate_tuple( - [&](auto i) { - constexpr auto num_i = Number{}; - if constexpr(is_detected>>::value) - { - // if tuple then recurrence - return CalculateBlockLengths(shape.At(num_i), tile_shape.At(num_i)); - } - else - { - return shape.At(num_i) / tile_shape.At(num_i); - } - }, - Number::Size()>{}); -} } // namespace /** - * \brief Create local partition for thread. + * \brief Create local partition for thread (At now only packed partition + * is supported). * * \param tensor Tensor for partition. - * \param thread_lengths Layout of threads. + * \param thread_lengths Layout of threads (could not be nested). * \param thread_id Thread index represented as integer. - * \param steps Thread step (default=1, raked partition) * \return Partition tensor. */ -template > -__host__ __device__ constexpr auto make_local_partition(const TensorType& tensor, - const ThreadLengthsTuple& thread_lengths, - const index_t thread_id, - const StepsTuple steps = StepsTuple{}) +template +__host__ __device__ constexpr auto +make_local_partition(TensorType& tensor, + [[maybe_unused]] const ThreadLengthsTuple& thread_lengths, + const index_t thread_id) { - // Create shape, strides and layout for new partition tensor - const auto partition_shape = CalculateLocalPartitionShape(shape(tensor), thread_lengths); - // Create new descriptor and layout - const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor(); - auto partition_desc = - CalculateLocalPartitionDescriptor(shape(tensor), thread_lengths, steps, flatten_desc); - const auto partition_layout = Layout( - partition_shape, partition_desc); - // Calculate offset for new partition tensor - const auto offset_idx = CalculateLayoutOffsetIdx(thread_lengths, steps, thread_id); - const auto partition_offset = layout(tensor)(offset_idx); - return make_tensor(tensor.GetPointer() + partition_offset, - partition_layout); + static_assert(!IsNestedTuple(ThreadLengthsTuple{})); + // Calculate new partition shape + const auto& tensor_shape = shape(tensor); + constexpr auto partition_shape = + CalculateLocalPartitionShape(decltype(tensor_shape){}, ThreadLengthsTuple{}); + // Create Thread Cluster Descriptor + constexpr auto partition_lengths_seq = generate_sequence_v2( + [&](auto I) { return size(partition_shape); }, Number{}); + constexpr auto thread_lengths_seq = + generate_sequence_v2([&](auto I) { return size(ThreadLengthsTuple{}); }, + Number{}); + constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq); + // Calculate thread idxs and offsets + const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id)); + const auto offset_multi_idxs = + CalculateOffsetMultiIdxs(thread_idxs, partition_lengths_seq, tensor.GetMultiIdxOffsets()); + // Create new layout and tensor + auto& flatten_desc = layout(tensor).GetUnrolledDescriptor(); + const auto partition_layout = + Layout, decltype(flatten_desc)>( + partition_shape, flatten_desc); + auto partition_tensor = + make_tensor(tensor.GetPointer(), partition_layout); + // Apply offsets + partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs)); + return partition_tensor; } /** - * \brief Create local tile for thread block. + * \brief Create local tile for thread block. (At now only packed tile + * is supported). + * + * \note Temporary to gain the best performance use 2d + * tile_shape. + * * * \param tensor Tensor for partition. * \param tile_shape Shapes of requested tile. - * \param block_idx Block index represented as tuple. - * \param steps Block step (default=1, raked partition) + * \param block_id Block index represented as integer. + * \return Tile tensor. */ -template > -__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor, - const BlockShapeTuple& tile_shape, - const BlockIdxTuple& block_idx, - const StepsTuple steps = StepsTuple{}) +template +__host__ __device__ constexpr auto +make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id) { - // Create block lengths, strides and layout for new tile tensor - const auto block_lengths = CalculateBlockLengths(shape(tensor), tile_shape); - // Create new descriptor and layout - const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor(); - auto tile_desc = - CalculateLocalPartitionDescriptor(tile_shape, block_lengths, steps, flatten_desc); - const auto tile_layout = Layout, decltype(tile_desc)>( - tile_shape, tile_desc); - // Calculate offset for new partition tensor - const auto offset_idx = CalculateLayoutOffsetIdx(steps, block_idx); - const auto tile_offset = layout(tensor)(offset_idx); - return make_tensor(tensor.GetPointer() + tile_offset, - tile_layout); + static_assert(!IsNestedTuple(BlockShapeTuple{})); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor(); + + if constexpr(BlockShapeTuple::Size() == I2) + { + // Optimized version for 2d tile shape [MxK] + const auto block_2_tile_map = + BlockToCTileMap_M00_N0_M01Adapt>(aligned_desc); + const auto block_work_idx = + block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id)); + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape)); + const index_t k_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape)); + const auto offset_multi_idxs = + make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid); + // Create new layout and tensor + const auto tile_layout = + Layout, decltype(aligned_desc)>(tile_shape, + aligned_desc); + auto tile_tensor = + make_tensor(tensor.GetPointer(), tile_layout); + // Apply offsets + tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs)); + return tile_tensor; + } + else + { + // Calculate offsets + // Sequence with data to process per block + constexpr auto tile_shape_seq = + generate_sequence_v2([](auto I) { return size(BlockShapeTuple{}.At(I)); }, + Number{}); + // Tuple with number of blocks + const auto block_lengths = CalculateGridSize(shape(tensor), tile_shape); + constexpr auto block_cluster_desc_ = make_cluster_descriptor(block_lengths); + const auto block_idxs = + block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id)); + const auto offset_multi_idxs = + CalculateOffsetMultiIdxs(block_idxs, tile_shape_seq, tensor.GetMultiIdxOffsets()); + // Create new layout and tensor + const auto tile_layout = + Layout, decltype(aligned_desc)>(tile_shape, + aligned_desc); + auto tile_tensor = + make_tensor(tensor.GetPointer(), tile_layout); + // Apply offsets + tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs)); + return tile_tensor; + } } } // namespace wrapper diff --git a/include/ck/wrapper/utils/tensor_utils.hpp b/include/ck/wrapper/utils/tensor_utils.hpp index 1e932e62e1..7ec080760a 100644 --- a/include/ck/wrapper/utils/tensor_utils.hpp +++ b/include/ck/wrapper/utils/tensor_utils.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -10,6 +10,7 @@ #include "ck/utility/tuple_helper.hpp" #include "ck/utility/dynamic_buffer.hpp" #include "ck/utility/amd_address_space.hpp" +#include "ck/utility/multi_index.hpp" namespace ck { namespace wrapper { @@ -27,16 +28,12 @@ using MemoryTypeEnum = AddressSpaceEnum; // Disable from doxygen docs generation /// @cond // forward declarations -template +template struct Layout; template - + typename UnrolledDescriptorType> struct Tensor; template @@ -45,13 +42,22 @@ struct Slice __host__ __device__ constexpr Slice() : from_(), to_() {} __host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {} + /** + * \brief Calculate slice range. + * + * \param dim Dimension size. + * \return Slice range. + */ template __host__ __device__ constexpr auto range(const T& dim) const { if constexpr(is_same_v || is_same_v || is_same_v) { - assert(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_) && "Invalid range"); + if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_))) + { + throw std::runtime_error("Invalid range"); + } if(to_ < 0) { return dim - from_ + to_ + 1; @@ -101,40 +107,27 @@ using is_tuple = decltype(std::declval().IsTuple()); template + typename UnrolledDescriptorType> constexpr auto make_tensor(ElementType* pointer, - const Layout& layout) + const Layout& layout) { - return Tensor(pointer, layout); + return Tensor(pointer, layout); } /** * \brief Make SGPR or VGPR tensor function. * * \tparam MemoryType Type of memory. - * \tparam NumVectors Number of vectors. - * \tparam ScalarPerVector Scalars per vector. * \tparam ElementType Memory data type. * \return Constructed tensor. */ template -constexpr auto make_register_tensor() + typename ElementType, + typename Shape, + typename UnrolledDescriptorType> +constexpr auto make_register_tensor(const Layout& layout) { - const auto layout = make_layout(make_tuple(Number{}), make_tuple(Number<1>{})); - return Tensor>, - std::remove_const_t>, - NumVectors, - ScalarPerVector>(layout); + return Tensor(layout); } /** @@ -146,15 +139,9 @@ constexpr auto make_register_tensor() template -__host__ __device__ constexpr const auto& layout(const Tensor& tensor) + typename UnrolledDescriptorType> +__host__ __device__ constexpr const auto& +layout(const Tensor& tensor) { return tensor.GetLayout(); } @@ -170,15 +157,9 @@ template -__host__ __device__ constexpr auto size(const Tensor& tensor) + typename UnrolledDescriptorType> +__host__ __device__ constexpr auto +size(const Tensor& tensor) { return size(tensor.GetLayout()); } @@ -194,15 +175,9 @@ template -__host__ __device__ constexpr auto rank(const Tensor& tensor) + typename UnrolledDescriptorType> +__host__ __device__ constexpr auto +rank(const Tensor& tensor) { return rank(tensor.GetLayout()); } @@ -218,15 +193,9 @@ template -__host__ __device__ constexpr auto depth(const Tensor& tensor) + typename UnrolledDescriptorType> +__host__ __device__ constexpr auto +depth(const Tensor& tensor) { return depth(tensor.GetLayout()); } @@ -240,15 +209,9 @@ __host__ __device__ constexpr auto depth(const Tensor -__host__ __device__ constexpr const auto& shape(const Tensor& tensor) + typename UnrolledDescriptorType> +__host__ __device__ constexpr const auto& +shape(const Tensor& tensor) { return shape(tensor.GetLayout()); } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp index 45e35ec56d..5f2ab12164 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp @@ -265,6 +265,8 @@ struct ReferenceColumnToImage : public device::BaseOperator return 0; } + throw std::runtime_error("Col2Img: number of dimensions should be between 1 and 3."); + return 1; } float Run(const device::BaseArgument* p_arg, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp index 50040a2441..bfb8b48187 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp @@ -313,6 +313,9 @@ struct ReferenceConvBwdData : public device::BaseOperator return 0; } + throw std::runtime_error( + "Conv_bwd_data: number of dimensions must be between 1 and 3."); + return 1; } float Run(const device::BaseArgument* p_arg, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp index 02ad7a033a..d0b98efd1f 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp @@ -265,6 +265,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator return 0; } + throw std::runtime_error("Conv_bwd: number of dimensions must be between 1 and 3."); + return 1; } float Run(const device::BaseArgument* p_arg, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp index ffc9470df2..d63b5256f9 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -360,6 +360,8 @@ struct ReferenceConvFwd : public device::BaseOperator return 0; } + throw std::runtime_error("Conv_fwd: number of dimensions must be between 1 and 3."); + return 1; } float Run(const device::BaseArgument* p_arg, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 6e39dee71c..4d52563f42 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -63,12 +63,11 @@ struct ReferenceGemm : public device::BaseOperator const int K = arg.a_m_k_.mDesc.GetLengths()[1]; AccDataType v_acc = 0; + ComputeTypeA v_a = 0; + ComputeTypeB v_b = 0; for(int k = 0; k < K; ++k) { - ComputeTypeA v_a; - ComputeTypeB v_b; - // use PassThrough instead of ConvertBF16RTN for reference calculation if constexpr(is_same_v) @@ -94,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator ck::type_convert(v_a) * ck::type_convert(v_b); } - CDataType v_c; + CDataType v_c = 0; arg.c_element_op_(v_c, v_acc); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp index 56b0ce7914..4682c5c223 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -10,6 +10,7 @@ #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/numeric.hpp" namespace ck { namespace tensor_operation { @@ -229,6 +230,8 @@ struct ReferenceImageToColumn : public device::BaseOperator return 0; } + throw std::runtime_error("Img2Col: number of dimensions should be between 1 and 3."); + return 1; } float Run(const device::BaseArgument* p_arg, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp index 77ad36b97b..42ca8e755d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp @@ -106,9 +106,8 @@ struct DeviceOperationInstanceFactory< return op_ptrs; } }; - +#endif } // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp index 2df378b0c6..730785f702 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp @@ -114,9 +114,8 @@ struct DeviceOperationInstanceFactory +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +#ifdef CK_ENABLE_FP32 +// FP32 +void add_device_groupnorm_bwd_gamma_beta_f32_instances( + std::vector>>&); +#endif +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta> +{ + using DeviceOp = DeviceNormalizationBwdGammaBeta; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_groupnorm_bwd_gamma_beta_f32_instances(op_ptrs); + } +#endif + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp b/library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp new file mode 100644 index 0000000000..e2736ac77e --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +#ifdef CK_ENABLE_FP16 +// FP16 +void add_device_layernorm2d_bwd_gamma_beta_f16_instances( + std::vector>>&); +#endif +#ifdef CK_ENABLE_FP32 +// FP32 +void add_device_layernorm2d_bwd_gamma_beta_f32_instances( + std::vector>>&); +#endif +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta> +{ + using DeviceOp = DeviceNormalizationBwdGammaBeta; + + static auto GetInstances() + { + std::vector> op_ptrs; +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 2 && NumReduceDim == 1) + { + add_device_layernorm2d_bwd_gamma_beta_f16_instances(op_ptrs); + } + } +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 2 && NumReduceDim == 1) + { + add_device_layernorm2d_bwd_gamma_beta_f32_instances(op_ptrs); + } + } +#endif + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp index 06a117919a..d547b3e602 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -7,6 +7,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -57,7 +58,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple< DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV2< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves , diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp index 48351b2f29..60d4ccf525 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -7,6 +7,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -52,7 +53,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances = std::tuple< DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV2< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves , diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp index ad846e4c80..4a2526b3a4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -7,6 +7,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -57,7 +58,8 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV2< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves , diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp index 3c50cf2273..01e0ebdb34 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -7,6 +7,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -52,7 +53,8 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV2< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves , diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp index 9fd83cdec8..45096f659f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -27,6 +27,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding; static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; @@ -110,17 +111,39 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< // clang-format on >; -template +template using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances = std::tuple< // clang-format off //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave> + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche> // clang-format on >; @@ -141,9 +164,51 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( add_device_operation_instances( instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{}); - add_device_operation_instances( - instances, - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmDefault, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmDefault, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmDefault, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmKPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmMNKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmMNKPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmMNKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp index 25c94bb886..b22f4a3beb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -27,6 +27,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding; static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; @@ -95,6 +96,41 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2> // clang-format on >; +template +using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche> + // clang-format on + >; void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( std::vector{}); + + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmDefault, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmDefault, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmDefault, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmKPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmMNKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmMNKPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmMNKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp index aa399f56ec..160bcb4ace 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp @@ -8,7 +8,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f16_instances( +void add_device_layernorm2d_bwd_gamma_beta_f16_instances( std::vector>>& instances) { diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp index ba2966ba37..6f42eca0b9 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp @@ -8,7 +8,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f32_instances( +void add_device_layernorm2d_bwd_gamma_beta_f32_instances( std::vector>>& instances) { diff --git a/profiler/include/profiler/profile_gemm_impl.hpp b/profiler/include/profiler/profile_gemm_impl.hpp index 586a356ecc..0419ccd8e7 100644 --- a/profiler/include/profiler/profile_gemm_impl.hpp +++ b/profiler/include/profiler/profile_gemm_impl.hpp @@ -298,7 +298,7 @@ int profile_gemm_impl(int do_verification, } } - return pass ? 0 : 1; + return pass; } } // namespace profiler diff --git a/profiler/include/profiler/profile_gemm_splitk_impl.hpp b/profiler/include/profiler/profile_gemm_splitk_impl.hpp index 6816d2c538..5d5ae1ad15 100644 --- a/profiler/include/profiler/profile_gemm_splitk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_splitk_impl.hpp @@ -145,7 +145,7 @@ bool profile_gemm_splitk_impl(int do_verification, // profile device GEMM instances for(auto& op_ptr : op_ptrs) { - std::vector kbatch_list = {1, 2, 4, 8, 12, 16, 20, 32, 36, 40, 64, 96, 128}; + std::vector kbatch_list = {1, 2, 4, 8, 12, 16, 19, 20, 32, 38}; if(KBatch > 0) { diff --git a/profiler/include/profiler/profile_groupnorm_bwd_gamma_beta_impl.hpp b/profiler/include/profiler/profile_groupnorm_bwd_gamma_beta_impl.hpp new file mode 100644 index 0000000000..5e9d3df1b1 --- /dev/null +++ b/profiler/include/profiler/profile_groupnorm_bwd_gamma_beta_impl.hpp @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_groupnorm_bwd_gamma_beta_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector length) +{ + // we don't need GammaDataType and DXDataType here, just for reference class + using GammaDataType = DYDataType; + using DXDataType = DYDataType; + + if(length.size() != 5) + return false; + + index_t N = length[0]; + index_t G = length[3]; + index_t C = length[4]; + + std::vector reduce_dim = {0, 1, 2}; + std::vector gamma_beta_length = {G, C}; + + Tensor dy(length); + Tensor x(length); + Tensor gamma(gamma_beta_length); // dummy tensor, for reference + Tensor mean({N, G}); + Tensor inv_std({N, G}); + Tensor dgamma(gamma_beta_length); + Tensor dbeta(gamma_beta_length); + + Tensor host_dx(length); // dummy tensor, for reference + Tensor host_dgamma(gamma_beta_length); + Tensor host_dbeta(gamma_beta_length); + + std::vector strideDy = + std::vector{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()}; + std::vector strideX = + std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}; + + std::vector strideDGamma{dgamma.mDesc.GetStrides().begin(), + dgamma.mDesc.GetStrides().end()}; + + std::vector strideDBeta{dbeta.mDesc.GetStrides().begin(), + dbeta.mDesc.GetStrides().end()}; + + std::vector strideMeanInvStd = {G, 0, 0, 1, 0}; + + switch(init_method) + { + case 0: + dy.GenerateTensorValue(GeneratorTensor_1{}); + x.GenerateTensorValue(GeneratorTensor_1{}); + mean.GenerateTensorValue(GeneratorTensor_1{}); + inv_std.GenerateTensorValue(GeneratorTensor_1{}); + dgamma.GenerateTensorValue(GeneratorTensor_1{}); + dbeta.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 1: + dy.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + mean.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + inv_std.GenerateTensorValue(GeneratorTensor_2{0, 5}); + dgamma.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + dbeta.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + dy.GenerateTensorValue(GeneratorTensor_3{0, 1}); + x.GenerateTensorValue(GeneratorTensor_3{0, 1}); + mean.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + inv_std.GenerateTensorValue(GeneratorTensor_3{0, 0.5}); + dgamma.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + dbeta.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); + DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize()); + DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize()); + + dy_dev.ToDevice(dy.mData.data()); + x_dev.ToDevice(x.mData.data()); + mean_dev.ToDevice(mean.mData.data()); + inv_std_dev.ToDevice(inv_std.mData.data()); + + // add device normalization instances + using DeviceOp = + ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceInstance = + ck::tensor_operation::host::ReferenceGroupnormBwd; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, length); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + } + + std::size_t num_bytes = dy.mDesc.GetElementSize() * sizeof(DYDataType) + + x.mDesc.GetElementSize() * sizeof(XDataType) + + mean.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + inv_std.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + dgamma.mDesc.GetElementSize() * sizeof(DGammaDataType) + + dbeta.mDesc.GetElementSize() * sizeof(DBetaDataType); + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer(length, + strideDy, + strideX, + strideMeanInvStd, + strideMeanInvStd, + gamma_beta_length, + strideDGamma, + strideDBeta, + reduce_dim, + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = ", length, ", ") << std::endl; + } + + continue; + } + + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + dgamma_dev.FromDevice(dgamma.mData.data()); + dbeta_dev.FromDevice(dbeta.mData.data()); + bool pass = + ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3); + + pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3); + + if(do_log) + { + LogRangeAsType(std::cout << "dy : ", dy.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_dgamma : ", host_dgamma.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "dgamma : ", dgamma.mData, ",") << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", length, ",") << ", "; + LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s," + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_layernorm_bwd_gamma_beta_impl.hpp b/profiler/include/profiler/profile_layernorm_bwd_gamma_beta_impl.hpp new file mode 100644 index 0000000000..10fa9c86d5 --- /dev/null +++ b/profiler/include/profiler/profile_layernorm_bwd_gamma_beta_impl.hpp @@ -0,0 +1,263 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_layernorm_bwd_gamma_beta_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector length) +{ + // we don't need GammaDataType and DXDataType here, just for reference class + using GammaDataType = DYDataType; + using DXDataType = DYDataType; + + if(length.size() != Rank || Rank < 2) + return false; + + // Assume normalize dimension for first dimension + // Layernorm 2D, input = [M, K], reduce on M axis + // Layernorm 4D, input = [N, H, W, C], redice on N axis + constexpr int NumReduceDim = Rank - 1; + + std::vector reduce_dim = {0}; + std::vector invarient_length{length.begin() + 1, length.end()}; + + Tensor dy(length); + Tensor x(length); + Tensor gamma(invarient_length); // dummy tensor, for reference + Tensor mean({length[0]}); + Tensor inv_std({length[0]}); + Tensor dgamma(invarient_length); + Tensor dbeta(invarient_length); + + Tensor host_dx(length); // dummy tensor, for reference + Tensor host_dgamma(invarient_length); + Tensor host_dbeta(invarient_length); + + std::vector strideDy = + std::vector{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()}; + std::vector strideX = strideDy; + + std::vector strideDGamma{dgamma.mDesc.GetStrides().begin(), + dgamma.mDesc.GetStrides().end()}; + + std::vector strideDBeta{dbeta.mDesc.GetStrides().begin(), + dbeta.mDesc.GetStrides().end()}; + + std::vector strideMeanInvStd{Rank, 0}; + strideMeanInvStd[0] = 1; + + switch(init_method) + { + case 0: + dy.GenerateTensorValue(GeneratorTensor_1{}); + x.GenerateTensorValue(GeneratorTensor_1{}); + mean.GenerateTensorValue(GeneratorTensor_1{}); + inv_std.GenerateTensorValue(GeneratorTensor_1{}); + dgamma.GenerateTensorValue(GeneratorTensor_1{}); + dbeta.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 1: + dy.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + mean.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + inv_std.GenerateTensorValue(GeneratorTensor_2{0, 5}); + dgamma.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + dbeta.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + dy.GenerateTensorValue(GeneratorTensor_3{0, 1}); + x.GenerateTensorValue(GeneratorTensor_3{0, 1}); + mean.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + inv_std.GenerateTensorValue(GeneratorTensor_3{0, 0.5}); + dgamma.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + dbeta.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); + DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize()); + DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize()); + + dy_dev.ToDevice(dy.mData.data()); + x_dev.ToDevice(x.mData.data()); + mean_dev.ToDevice(mean.mData.data()); + inv_std_dev.ToDevice(inv_std.mData.data()); + + // add device normalization instances + using DeviceOp = + ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceInstance = + ck::tensor_operation::host::ReferenceLayernormBwd; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, length); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + } + + std::size_t num_bytes = dy.mDesc.GetElementSize() * sizeof(DYDataType) + + x.mDesc.GetElementSize() * sizeof(XDataType) + + mean.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + inv_std.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + dgamma.mDesc.GetElementSize() * sizeof(DGammaDataType) + + dbeta.mDesc.GetElementSize() * sizeof(DBetaDataType); + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer(length, + strideDy, + strideX, + strideMeanInvStd, + strideMeanInvStd, + invarient_length, + strideDGamma, + strideDBeta, + reduce_dim, + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = ", length, ", ") << std::endl; + } + + continue; + } + + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + dgamma_dev.FromDevice(dgamma.mData.data()); + dbeta_dev.FromDevice(dbeta.mData.data()); + bool pass = + ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3); + + pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3); + + if(do_log) + { + LogRangeAsType(std::cout << "dy : ", dy.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_dgamma : ", host_dgamma.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "dgamma : ", dgamma.mData, ",") << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", length, ",") << ", "; + LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s," + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 68ef04ed11..e9cf6eecfb 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -19,6 +19,8 @@ set(PROFILER_SOURCES profile_groupnorm_bwd_data.cpp profile_groupnorm_fwd.cpp profile_layernorm_bwd_data.cpp + profile_layernorm_bwd_gamma_beta.cpp + profile_groupnorm_bwd_gamma_beta.cpp profile_layernorm_fwd.cpp profile_max_pool3d_fwd.cpp profile_avg_pool3d_bwd.cpp @@ -82,6 +84,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) diff --git a/profiler/src/profile_gemm.cpp b/profiler/src/profile_gemm.cpp index 0d6c5021f3..c322c7054b 100644 --- a/profiler/src/profile_gemm.cpp +++ b/profiler/src/profile_gemm.cpp @@ -137,8 +137,14 @@ int profile_gemm(int argc, char* argv[]) return pass ? 0 : 1; }; - if(false) - ; + if(data_type != GemmDataType::F32_F32_F32 && data_type != GemmDataType::F16_F16_F16 && + data_type != GemmDataType::BF16_BF16_BF16 && data_type != GemmDataType::INT8_INT8_INT8 && + data_type != GemmDataType::F8_F8_F8) + { + // dummy clause before the else clauses for different data types + std::cout << "Gemm: this data_type is not implemented" << std::endl; + return 1; + } #ifdef CK_ENABLE_FP32 else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) { @@ -231,7 +237,7 @@ int profile_gemm(int argc, char* argv[]) #endif else { - std::cout << "this data_type & layout is not implemented" << std::endl; + std::cout << "Gemm: this data_type & layout is not implemented" << std::endl; return 1; } diff --git a/profiler/src/profile_groupnorm_bwd_gamma_beta.cpp b/profiler/src/profile_groupnorm_bwd_gamma_beta.cpp new file mode 100644 index 0000000000..7fcef3a4e2 --- /dev/null +++ b/profiler/src/profile_groupnorm_bwd_gamma_beta.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "profiler/data_type_enum.hpp" +#include "profiler/profile_groupnorm_bwd_gamma_beta_impl.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; + +struct groupnormBwdGammaBetaArgParser +{ + std::unordered_map> long_opts = {{"length", {}}}; + + bool parse_opt(int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +void print_help_groupnorm_bwd_gamma_beta() +{ + // eg: ckProfiler groupnorm_bwd_gamma_beta 1 0 2 0 1 --length 1 16 16 32 40 + std::cout << "arg1: data type (0: fp16; 1: fp32)\n" + << "arg2: verification (0: no; 1: yes)\n" + << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg4: print tensor value (0: no; 1: yes)\n" + << "arg5: time kernel (0=no, 1=yes)\n" + << "--length: tensor extents (e.g, --length 1 16 16 32 40) \n" + << std::endl; +} + +int profile_groupnorm_bwd_gamma_beta(int argc, char* argv[]) +{ + if(argc <= 2) + { + print_help_groupnorm_bwd_gamma_beta(); + return 0; + } + + groupnormBwdGammaBetaArgParser arg_parser; + + // short unnamed options + const ck::DataTypeEnum data_type = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const bool time_kernel = std::stoi(argv[6]); + + // parse the long options + arg_parser(argc, argv); + const std::vector length = arg_parser.long_opts["length"]; + + using F32 = float; + + if(length.size() == 5) + { + if(data_type == ck::DataTypeEnum::Float) + { + ck::profiler::profile_groupnorm_bwd_gamma_beta_impl( + do_verification, init_method, do_log, time_kernel, length); + } + else + { + throw std::runtime_error("not implemented yet"); + } + } + else + { + throw std::runtime_error("length should be 5"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION("groupnorm_bwd_gamma_beta", + "Group Normalization", + profile_groupnorm_bwd_gamma_beta); diff --git a/profiler/src/profile_layernorm_bwd_gamma_beta.cpp b/profiler/src/profile_layernorm_bwd_gamma_beta.cpp new file mode 100644 index 0000000000..0f3436c663 --- /dev/null +++ b/profiler/src/profile_layernorm_bwd_gamma_beta.cpp @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "profiler/data_type_enum.hpp" +#include "profiler/profile_layernorm_bwd_gamma_beta_impl.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; + +struct layernormBwdGammaBetaArgParser +{ + std::unordered_map> long_opts = {{"length", {}}}; + + bool parse_opt(int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +void print_help_layernorm_bwd_gamma_beta() +{ + // eg: ckProfiler layernorm_bwd_gamma_beta 0 0 2 0 1 --length 1502 4096 + std::cout << "arg1: data type (0: fp16; 1: fp32)\n" + << "arg2: verification (0: no; 1: yes)\n" + << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg4: print tensor value (0: no; 1: yes)\n" + << "arg5: time kernel (0=no, 1=yes)\n" + << "--length: tensor extents (e.g, --length 1024 1024) \n" + << std::endl; +} + +int profile_layernorm_bwd_gamma_beta(int argc, char* argv[]) +{ + if(argc <= 2) + { + print_help_layernorm_bwd_gamma_beta(); + return 0; + } + + layernormBwdGammaBetaArgParser arg_parser; + + // short unnamed options + const ck::DataTypeEnum data_type = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const bool time_kernel = std::stoi(argv[6]); + + // parse the long options + arg_parser(argc, argv); + const std::vector length = arg_parser.long_opts["length"]; + + using F16 = ck::half_t; + using F32 = float; + + if(length.size() == 2) + { + constexpr int rank = 2; + + if(data_type == ck::DataTypeEnum::Half) + { + ck::profiler::profile_layernorm_bwd_gamma_beta_impl( + do_verification, init_method, do_log, time_kernel, length); + } + else if(data_type == ck::DataTypeEnum::Float) + { + ck::profiler::profile_layernorm_bwd_gamma_beta_impl( + do_verification, init_method, do_log, time_kernel, length); + } + else + { + throw std::runtime_error("not implemented yet"); + } + } + else + { + throw std::runtime_error("not implemented yet"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION("layernorm_bwd_gamma_beta", + "Layer Normalization", + profile_layernorm_bwd_gamma_beta); diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh index da83254f00..728b8c1092 100755 --- a/script/clang-format-overwrite.sh +++ b/script/clang-format-overwrite.sh @@ -1,2 +1,2 @@ -#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' +find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 90140659f6..fa5f8583af 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -140,6 +140,7 @@ add_subdirectory(block_to_ctile_map) add_subdirectory(softmax) add_subdirectory(normalization_fwd) add_subdirectory(normalization_bwd_data) +add_subdirectory(normalization_bwd_gamma_beta) add_subdirectory(data_type) add_subdirectory(elementwise_normalization) add_subdirectory(batchnorm) diff --git a/test/conv_tensor_rearrange/test_conv_tensor_rearrange_interface.cpp b/test/conv_tensor_rearrange/test_conv_tensor_rearrange_interface.cpp index 67c0f2698c..df8b77aba1 100644 --- a/test/conv_tensor_rearrange/test_conv_tensor_rearrange_interface.cpp +++ b/test/conv_tensor_rearrange/test_conv_tensor_rearrange_interface.cpp @@ -135,6 +135,8 @@ class TestConvTensorRearrangeInterface : public ::testing::Test return col2img.IsSupportedArgument(argument); } + throw std::runtime_error("Conv_tensor_rearrange: problem with tensor rearrange operator. "); + return 1; } }; diff --git a/test/normalization_bwd_gamma_beta/CMakeLists.txt b/test/normalization_bwd_gamma_beta/CMakeLists.txt new file mode 100644 index 0000000000..f3579aad08 --- /dev/null +++ b/test/normalization_bwd_gamma_beta/CMakeLists.txt @@ -0,0 +1,13 @@ +add_custom_target(test_normalization_bwd_gamma_beta) +add_gtest_executable(test_layernorm2d_bwd_gamma_beta_fp32 test_layernorm2d_bwd_gamma_beta_fp32.cpp) +if(result EQUAL 0) + target_link_libraries(test_layernorm2d_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) + add_dependencies(test_normalization_bwd_gamma_beta test_layernorm2d_bwd_gamma_beta_fp32) +endif() + +add_gtest_executable(test_groupnorm_bwd_gamma_beta_fp32 test_groupnorm_bwd_gamma_beta_fp32.cpp) +if(result EQUAL 0) + target_link_libraries(test_groupnorm_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) + add_dependencies(test_normalization_bwd_gamma_beta test_groupnorm_bwd_gamma_beta_fp32) +endif() + diff --git a/test/normalization_bwd_gamma_beta/test_groupnorm_bwd_gamma_beta_fp32.cpp b/test/normalization_bwd_gamma_beta/test_groupnorm_bwd_gamma_beta_fp32.cpp new file mode 100644 index 0000000000..ab9cb29891 --- /dev/null +++ b/test/normalization_bwd_gamma_beta/test_groupnorm_bwd_gamma_beta_fp32.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "profiler/profile_groupnorm_bwd_gamma_beta_impl.hpp" + +using F16 = ck::half_t; +using F32 = float; +using ck::index_t; + +template +class TestgroupnormBwdGammaBeta : public ::testing::Test +{ + protected: + using DYDataType = std::tuple_element_t<0, Tuple>; + using XDataType = std::tuple_element_t<1, Tuple>; + using MeanInvStdDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using DGammaDataType = std::tuple_element_t<4, Tuple>; + using DBetaDataType = std::tuple_element_t<5, Tuple>; + + void Run() + { + // Bwd data: [N, H, W, G, C], reduce H, W, C + std::vector> lengths = {{1, 1, 1, 1, 1}, + {1, 2, 3, 4, 5}, + {256, 9, 9, 9, 9}, + {1, 64, 64, 32, 10}, + {1, 32, 32, 32, 20}, + {1, 16, 16, 32, 40}}; + + for(auto length : lengths) + { + bool success = ck::profiler::profile_groupnorm_bwd_gamma_beta_impl( + true, 2, false, false, length); + EXPECT_TRUE(success); + } + } +}; + +using KernelTypes = ::testing::Types< + // DYDataType XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType> + std::tuple>; + +TYPED_TEST_SUITE(TestgroupnormBwdGammaBeta, KernelTypes); +TYPED_TEST(TestgroupnormBwdGammaBeta, Test_FP32) { this->Run(); } diff --git a/test/normalization_bwd_gamma_beta/test_layernorm2d_bwd_gamma_beta_fp32.cpp b/test/normalization_bwd_gamma_beta/test_layernorm2d_bwd_gamma_beta_fp32.cpp new file mode 100644 index 0000000000..53c92413b1 --- /dev/null +++ b/test/normalization_bwd_gamma_beta/test_layernorm2d_bwd_gamma_beta_fp32.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "profiler/profile_layernorm_bwd_gamma_beta_impl.hpp" + +using F16 = ck::half_t; +using F32 = float; +using ck::index_t; + +template +class TestLayernorm2dBwdGammaBeta : public ::testing::Test +{ + protected: + using DYDataType = std::tuple_element_t<0, Tuple>; + using XDataType = std::tuple_element_t<1, Tuple>; + using MeanInvStdDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using DGammaDataType = std::tuple_element_t<4, Tuple>; + using DBetaDataType = std::tuple_element_t<5, Tuple>; + + void Run() + { + // Bwd data: [N, D], reduce D + std::vector> lengths = { + {4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}}; + + for(auto length : lengths) + { + bool success = ck::profiler::profile_layernorm_bwd_gamma_beta_impl( + true, 2, false, false, length); + EXPECT_TRUE(success); + } + } +}; + +using KernelTypes = ::testing::Types< + // DYDataType XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType> + std::tuple>; + +TYPED_TEST_SUITE(TestLayernorm2dBwdGammaBeta, KernelTypes); +TYPED_TEST(TestLayernorm2dBwdGammaBeta, Test_FP32) { this->Run(); } diff --git a/test/wrapper/test_copy.cpp b/test/wrapper/test_copy.cpp index 5cf09a54be..e7fa3c539b 100644 --- a/test/wrapper/test_copy.cpp +++ b/test/wrapper/test_copy.cpp @@ -21,49 +21,59 @@ template + bool UseOptimizedCopy> __global__ void TestCopyDevice(const InputTensor input_tensor, OutputTensor output_tensor, const BlockShape tile_shape, - const ThreadLayoutShape thread_layout, - const LocalTileSteps block_steps, - const LocalPartitionSteps thread_steps) + const ThreadLayoutShape thread_layout) { __shared__ ck::index_t p_shared[ck::wrapper::size(tile_shape)]; - auto tensor_lds = ck::wrapper::make_tensor( + const auto tensor_lds = ck::wrapper::make_tensor( p_shared, ck::wrapper::make_layout(tile_shape)); - const auto block_idxs = ck::make_tuple(ck::make_tuple(0, 0), blockIdx.x); + const auto block_idx = static_cast(blockIdx.x); // Get local tiles for global memory - const auto input_local_tile = - ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs, block_steps); + const auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx); const auto output_local_tile = - ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs, block_steps); + ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx); // Get partition per thread - const auto input_local_partition = ck::wrapper::make_local_partition( - input_local_tile, thread_layout, threadIdx.x, thread_steps); + const auto input_local_partition = + ck::wrapper::make_local_partition(input_local_tile, thread_layout, threadIdx.x); auto lds_local_partition = - ck::wrapper::make_local_partition(tensor_lds, thread_layout, threadIdx.x, thread_steps); - auto output_local_partition = ck::wrapper::make_local_partition( - output_local_tile, thread_layout, threadIdx.x, thread_steps); + ck::wrapper::make_local_partition(tensor_lds, thread_layout, threadIdx.x); + auto output_local_partition = + ck::wrapper::make_local_partition(output_local_tile, thread_layout, threadIdx.x); // Allocate VGPR - constexpr ck::index_t scalar_per_vector = 1; - constexpr ck::index_t vgpr_size = ck::wrapper::size(lds_local_partition); - auto tensor_vgpr = ck::wrapper::make_register_tensor(); + auto tensor_vgpr = + ck::wrapper::make_register_tensor( + layout(lds_local_partition)); // Perform copy - ck::wrapper::copy(input_local_partition, lds_local_partition); - ck::wrapper::copy(lds_local_partition, tensor_vgpr); - ck::wrapper::copy(tensor_vgpr, output_local_partition); + if constexpr(UseOptimizedCopy) + { + using DimAccessOrder = ck::Tuple, ck::Number<0>>; + constexpr ck::index_t vector_dim = 0; + constexpr ck::index_t scalar_per_vector = 2; + ck::wrapper::copy(input_local_partition, + lds_local_partition); + // TODO: Enable optimized copy for static buffers + ck::wrapper::copy(lds_local_partition, + tensor_vgpr); + ck::wrapper::copy(tensor_vgpr, + output_local_partition); + } + else + { + ck::wrapper::copy(input_local_partition, lds_local_partition); + ck::wrapper::copy(lds_local_partition, tensor_vgpr); + ck::wrapper::copy(tensor_vgpr, output_local_partition); + } } +template void PerformCopyGlobalToGlobalViaLDS() { const auto shape = @@ -89,15 +99,8 @@ void PerformCopyGlobalToGlobalViaLDS() auto output_tensor_global = ck::wrapper::make_tensor( static_cast(out_buf.GetDeviceBuffer()), layout); - const auto thread_layout = - ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}), ck::Number<32>{}); - const auto tile_shape = - ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<64>{}); - - const auto thread_steps = - ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}), ck::Number<2>{}); - const auto block_steps = - ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}), ck::Number<64>{}); + const auto thread_layout = ck::make_tuple(ck::Number<1>{}, ck::Number<32>{}); + const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}); const ck::index_t grid_size = ck::math::integer_divide_ceil( ck::wrapper::size(input_tensor_global), ck::wrapper::size(tile_shape)); @@ -106,8 +109,7 @@ void PerformCopyGlobalToGlobalViaLDS() decltype(output_tensor_global), decltype(tile_shape), decltype(thread_layout), - decltype(block_steps), - decltype(thread_steps)>; + UseOptimizedCopy>; launch_and_time_kernel(StreamConfig{}, kernel, dim3(grid_size), @@ -116,9 +118,7 @@ void PerformCopyGlobalToGlobalViaLDS() input_tensor_global, output_tensor_global, tile_shape, - thread_layout, - block_steps, - thread_steps); + thread_layout); // Verify results std::vector output_data(ck::wrapper::size(shape)); @@ -126,4 +126,5 @@ void PerformCopyGlobalToGlobalViaLDS() EXPECT_TRUE(ck::utils::check_err(output_data, input_data)); } -TEST(TestCopy, CopyGlobalToGlobalViaLDS) { PerformCopyGlobalToGlobalViaLDS(); } +TEST(TestCopyGlobalToGlobalViaLDS, GenericCopy) { PerformCopyGlobalToGlobalViaLDS(); } +TEST(TestCopyGlobalToGlobalViaLDS, OptimizedCopy) { PerformCopyGlobalToGlobalViaLDS(); } diff --git a/test/wrapper/test_partition.cpp b/test/wrapper/test_partition.cpp index df56b879f6..cacbfe9d88 100644 --- a/test/wrapper/test_partition.cpp +++ b/test/wrapper/test_partition.cpp @@ -29,42 +29,29 @@ TEST(TestPartition, LocalPartition) const auto tensor = ck::wrapper::make_tensor(data.data(), layout); - const auto thread_steps = - ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<1>{}), ck::Number<1>{}); - const auto thread_layout = - ck::make_tuple(ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}), ck::Number<1>{}); - - for(ck::index_t thread_id = 0; thread_id < ck::wrapper::size(thread_layout); thread_id++) - { - const auto raked_partition = - ck::wrapper::make_local_partition(tensor, thread_layout, thread_id); - - const auto expected_partition_size = - ck::wrapper::size(tensor) / ck::wrapper::size(thread_layout); - EXPECT_EQ(ck::wrapper::size(raked_partition), expected_partition_size); - EXPECT_EQ(raked_partition(0), thread_id); - } + const auto thread_steps = ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}); + const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}); for(ck::index_t thread_id = 0; thread_id < ck::wrapper::size(thread_layout); thread_id++) { const auto packed_partition = - ck::wrapper::make_local_partition(tensor, thread_layout, thread_id, thread_steps); + ck::wrapper::make_local_partition(tensor, thread_layout, thread_id); const auto expected_partition_size = ck::wrapper::size(tensor) / ck::wrapper::size(thread_layout); - const auto expected_partition_first_val = thread_id * ck::wrapper::size<0, 0>(thread_steps); + const auto expected_partition_first_val = thread_id * ck::wrapper::size<0>(thread_steps); + const auto expected_partition_second_val = expected_partition_first_val + 1; EXPECT_EQ(ck::wrapper::size(packed_partition), expected_partition_size); EXPECT_EQ(packed_partition(0), expected_partition_first_val); + EXPECT_EQ(packed_partition(1), expected_partition_second_val); } } TEST(TestPartition, LocalTile) { - const auto shape = - ck::make_tuple(ck::make_tuple(ck::Number<16>{}, ck::Number<4>{}), ck::Number<4>{}); - const auto strides = - ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<16>{}), ck::Number<64>{}); - const auto layout = ck::wrapper::make_layout(shape, strides); + const auto shape = ck::make_tuple(ck::Number<16>{}, ck::Number<4>{}, ck::Number<4>{}); + const auto strides = ck::make_tuple(ck::Number<1>{}, ck::Number<16>{}, ck::Number<64>{}); + const auto layout = ck::wrapper::make_layout(shape, strides); std::vector data(ck::wrapper::size(layout)); std::iota(data.begin(), data.end(), 0); @@ -72,48 +59,34 @@ TEST(TestPartition, LocalTile) const auto tensor = ck::wrapper::make_tensor(data.data(), layout); - const auto block_steps = - ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}), ck::Number<2>{}); - const auto block_shape = - ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}), ck::Number<2>{}); - const auto block_layout = - ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}), ck::Number<2>{}); + const auto block_shape = ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{}); + const auto num_blocks = + ck::make_tuple(ck::wrapper::size<0>(shape) / ck::wrapper::size<0>(block_shape), + ck::wrapper::size<1>(shape) / ck::wrapper::size<1>(block_shape), + ck::wrapper::size<2>(shape) / ck::wrapper::size<2>(block_shape)); + std::vector block_idxs(ck::wrapper::size(num_blocks)); + std::iota(block_idxs.begin(), block_idxs.end(), 0); - std::vector, ck::index_t>> block_idxs; - for(ck::index_t x = 0; x < ck::wrapper::size<0, 0>(block_layout); x++) + for(auto block_idx : block_idxs) { - for(ck::index_t y = 0; y < ck::wrapper::size<0, 1>(block_layout); y++) - { - for(ck::index_t z = 0; z < ck::wrapper::size<1>(block_layout); z++) - { - block_idxs.emplace_back(ck::make_tuple(x, y), z); - } - } - } - - for(const auto& block_idx : block_idxs) - { - const auto raked_tile = ck::wrapper::make_local_tile(tensor, block_shape, block_idx); + const auto packed_tile = ck::wrapper::make_local_tile(tensor, block_shape, block_idx); const auto expected_tile_size = ck::wrapper::size(block_shape); - EXPECT_EQ(ck::wrapper::size(raked_tile), expected_tile_size); - EXPECT_EQ(raked_tile(0), layout(block_idx)); - } + auto expected_tile_first_val = (block_idx % ck::wrapper::size<2>(num_blocks)) * + ck::wrapper::size<2>(block_shape) * + ck::wrapper::size<2>(strides); + block_idx /= ck::wrapper::size<2>(num_blocks); + expected_tile_first_val += (block_idx % ck::wrapper::size<1>(num_blocks)) * + ck::wrapper::size<1>(block_shape) * + ck::wrapper::size<1>(strides); + block_idx /= ck::wrapper::size<1>(num_blocks); + expected_tile_first_val += (block_idx % ck::wrapper::size<0>(num_blocks)) * + ck::wrapper::size<0>(block_shape) * + ck::wrapper::size<0>(strides); - for(const auto& block_idx : block_idxs) - { - const auto packed_tile = - ck::wrapper::make_local_tile(tensor, block_shape, block_idx, block_steps); - - const auto expected_tile_size = ck::wrapper::size(block_shape); - const auto expected_tile_first_val = - ck::wrapper::size<0, 0>(block_idx) * ck::wrapper::size<0, 0>(block_shape) * - ck::wrapper::size<0, 0>(strides) + - ck::wrapper::size<0, 1>(block_idx) * ck::wrapper::size<0, 1>(block_shape) * - ck::wrapper::size<0, 1>(strides) + - ck::wrapper::size<1>(block_idx) * ck::wrapper::size<1>(block_shape) * - ck::wrapper::size<1>(strides); + const auto expected_tile_second_val = expected_tile_first_val + 1; EXPECT_EQ(ck::wrapper::size(packed_tile), expected_tile_size); EXPECT_EQ(packed_tile(0), expected_tile_first_val); + EXPECT_EQ(packed_tile(1), expected_tile_second_val); } } diff --git a/test/wrapper/test_tensor.cpp b/test/wrapper/test_tensor.cpp index 2d4d6f2750..3c7d877528 100644 --- a/test/wrapper/test_tensor.cpp +++ b/test/wrapper/test_tensor.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -100,31 +100,26 @@ TEST(TestTensor, ReadWriteHostMemory) __global__ void TestTensorReadWriteDevice(void* data, void* success) { - constexpr ck::index_t nelems = 8; - constexpr ck::index_t scalar_per_vector = 1; + constexpr ck::index_t nelems = 8; __shared__ ck::index_t p_shared[nelems]; ck::index_t* casted_data_ptr = static_cast(data); bool* casted_success_ptr = static_cast(success); const auto layout = ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(2, 2), 2)); + constexpr auto vgpr_layout = + ck::wrapper::make_layout(make_tuple(ck::Number{}), make_tuple(ck::Number<1>{})); auto tensor_global = ck::wrapper::make_tensor(casted_data_ptr, layout); - auto tensor_lds = ck::wrapper::make_tensor(p_shared, layout); - auto tensor_vgpr = ck::wrapper::make_register_tensor(); - auto tensor_sgpr = ck::wrapper::make_register_tensor(); + auto tensor_lds = ck::wrapper::make_tensor(p_shared, layout); + auto tensor_vgpr = + ck::wrapper::make_register_tensor( + vgpr_layout); InitTensor(tensor_global); InitTensor(tensor_lds); StaticInitTensor(tensor_vgpr); - StaticInitTensor(tensor_sgpr); *casted_success_ptr = TestTensorCheck1d(tensor_global); *casted_success_ptr &= TestTensorCheck3d(tensor_global); @@ -133,8 +128,6 @@ __global__ void TestTensorReadWriteDevice(void* data, void* success) *casted_success_ptr &= TestTensorCheck3d(tensor_lds); *casted_success_ptr &= StaticTestTensorCheck1d(tensor_vgpr); - - *casted_success_ptr &= StaticTestTensorCheck1d(tensor_sgpr); } TEST(TestTensor, ReadWriteGlobalLdsRegistersMemory)