diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f17a4d768..c99fc1d065 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added +* Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4 * Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with with C++ and Python frontends starting with GEMM support. * Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle. * Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM. diff --git a/CMakeLists.txt b/CMakeLists.txt index 9f1bdf8689..356491d9c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -259,6 +259,11 @@ if ((SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx950" AND NOT FORCE_DISABLE_XDL) + message(STATUS "Enabling XDL FP8 gemms on gfx950") + add_definitions(-DCK_USE_GFX950) + set(CK_USE_GFX950 "ON") +endif() # new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA set(CK_TILE_USE_WMMA 0) diff --git a/Dockerfile.manylinux b/Dockerfile.manylinux new file mode 100644 index 0000000000..0683bcd4a6 --- /dev/null +++ b/Dockerfile.manylinux @@ -0,0 +1,101 @@ +FROM ghcr.io/rocm/therock_build_manylinux_x86_64:latest +ARG DEBIAN_FRONTEND=noninteractive +ARG ROCMVERSION=7.2 +ARG compiler_version="" +ARG compiler_commit="" +ARG CK_SCCACHE="" +ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ +ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn +ENV DEBIAN_FRONTEND=noninteractive + +USER root + +# Add rocm repository +RUN dnf clean all && dnf update -y && dnf -v install wget gnupg2 curl -y + +RUN wget https://repo.radeon.com/amdgpu-install/7.2/rhel/8.10/amdgpu-install-7.2.70200-1.el8.noarch.rpm && \ + dnf install ./amdgpu-install-7.2.70200-1.el8.noarch.rpm -y && \ + dnf update -y && \ + dnf install python3-setuptools python3-wheel -y && \ + dnf install rocm-dev -y + +## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined +ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache +ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin +ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION} +ENV CK_SCCACHE=$CK_SCCACHE +RUN if [ "$CK_SCCACHE" != "" ]; then \ + mkdir -p ${SCCACHE_INSTALL_LOCATION} && \ + curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \ + chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache; \ + fi + +# Install dependencies +RUN dnf update -y && DEBIAN_FRONTEND=noninteractive dnf install -y \ + cmake \ + clang-tools-extra \ + gcc-c++ \ + libstdc++ \ + libstdc++-devel \ + libstdc++-static \ + git \ + hip-rocclr \ + jq \ + mpich \ + net-tools \ + pkg-config \ + redis \ + sshpass \ + stunnel \ + vim \ + nano \ + zip \ + openssh-server \ + kmod && \ + dnf clean all && \ + rm -rf /var/lib/apt/lists/* && \ + rm -rf amdgpu-install* && \ +#Install latest ccache + git clone https://github.com/ccache/ccache.git && \ + cd ccache && mkdir build && cd build && cmake .. && make install && \ +#Install ClangBuildAnalyzer + git clone https://github.com/aras-p/ClangBuildAnalyzer.git && \ + cd ClangBuildAnalyzer/ && \ + make -f projects/make/Makefile && \ + cd / && \ +#Install latest cppcheck + git clone https://github.com/danmar/cppcheck.git && \ + cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . && \ + cd / && \ +# Install packages for processing the performance results + pip3 install --break-system-packages --upgrade pytest pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \ +# Add render group + groupadd -f render && \ +# Install the new rocm-cmake version + git clone -b master https://github.com/ROCm/rocm-cmake.git && \ + cd rocm-cmake && mkdir build && cd build && \ + cmake .. && cmake --build . && cmake --build . --target install + +WORKDIR / +# Add alternative compilers, if necessary +ENV compiler_version=$compiler_version +ENV compiler_commit=$compiler_commit +RUN sh -c "echo compiler version = '$compiler_version'" && \ + sh -c "echo compiler commit = '$compiler_commit'" + +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ + cd llvm-project && mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ + make -j 8 ; \ + else echo "using the release compiler"; \ + fi + +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ + cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ + make -j 8 ; \ + else echo "using the release compiler"; \ + fi + diff --git a/Jenkinsfile b/Jenkinsfile index f3a597e404..1a8be258bd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -39,10 +39,10 @@ def sendFailureNotifications() { // Error patterns to scan build logs for specific failure types and send detailed notifications. def failurePatterns = [ [pattern: /login attempt to .* failed with status: 401 Unauthorized/, description: "Docker registry authentication failed"], - [pattern: /docker login failed/, description: "Docker login failed"], + [pattern: /(.*)docker login failed(.*)/, description: "Docker login failed"], [pattern: /HTTP request sent .* 404 Not Found/, description: "HTTP request failed with 404"], [pattern: /cat: .* No such file or directory/, description: "GPU not found"], - [pattern: /GPU not found/, description: "GPU not found"], + [pattern: /(.*)GPU not found(.*)/, description: "GPU not found"], [pattern: /Could not connect to Redis at .* Connection timed out/, description: "Redis connection timed out"] ] diff --git a/example/01_gemm/gemm_wmma_fp16_v3.cpp b/example/01_gemm/gemm_wmma_fp16_v3.cpp index 5b10edd681..3b3b0fec16 100644 --- a/example/01_gemm/gemm_wmma_fp16_v3.cpp +++ b/example/01_gemm/gemm_wmma_fp16_v3.cpp @@ -19,22 +19,22 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // clang-format off using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, - PassThrough, PassThrough, PassThrough, GemmDefault, + PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, + S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, - 1, 1, 8, 1, - S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, - 1, 1, 8, 1, + 1, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp index ad1a4e0d10..e037be5a18 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantDecode; +using GemmConfig = GemmConfigQuantDecodeInterwave; // GemmConfigQuantPrefill is also supported for aquant grouped quantization // template diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index a95ca4862c..37117eaa0f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -93,6 +93,27 @@ struct GemmConfigQuantDecode : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); + + // static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template +struct GemmConfigQuantDecodeInterwave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); + + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template @@ -229,6 +250,8 @@ struct GemmConfigQuantPrefill : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); + + // static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 912527c929..ed1709a9ae 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -650,7 +650,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, else { ck_tile::FillConstant{static_cast(0x22)}(a_m_k); - ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr); + ck_tile::FillConstant{static_cast(1.0f)}(*aq_tensor_ptr); ck_tile::FillConstant{static_cast(0x38)}(b_k_n); if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) @@ -659,6 +659,184 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } } } + else if(init_method == 3) + { + if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::FillConstant{static_cast(0x38)}(a_m_k); + ck_tile::FillConstant{static_cast(0x22)}(b_k_n); + ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); + } + else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) + { + ck_tile::FillConstant{static_cast(0x38)}(a_m_k); + ck_tile::FillConstant{static_cast(0x22)}(b_k_n); + ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr); + ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); + } + else + { + ck_tile::FillConstant{static_cast(0x22)}(a_m_k); + ck_tile::FillConstant{static_cast(2.0f)}(*aq_tensor_ptr); + ck_tile::FillConstant{static_cast(0x38)}(b_k_n); + + if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) + { + ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); + } + } + } + else if(init_method == 4) + { + if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); + } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + } + ck_tile::FillUniformDistribution{2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + } + else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + } + else if(init_method == 5) + { + if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); + } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + } + else + { + ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(a_m_k); + } + // Fill aquant such that column j has value 2^j (1, 2, 4, 8, ...) + for(ck_tile::index_t row = 0; + row < static_cast(aq_tensor_ptr->get_length(0)); + ++row) + { + for(ck_tile::index_t col = 0; + col < static_cast(aq_tensor_ptr->get_length(1)); + ++col) + { + (*aq_tensor_ptr)(row, col) = static_cast(col + 1); + } + } + // std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl; + ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(b_k_n); + } + else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + } else { a_m_k.SetZero(); diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index 9ed1eebc3c..3b1ea65695 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -58,6 +58,7 @@ consteval BlockGemmSpec SetBlockGemm() case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break; case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break; case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break; + case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile."; case PipelineVersion::WEIGHT_ONLY: throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM."; default: throw "Unknown PipelineVersion"; @@ -92,6 +93,7 @@ consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K."; case PipelineVersion::V4: return ck_pipeline::v4; case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM."; + case PipelineVersion::V6: throw "PipelineVersion::V6 can be used only for CK TILE."; case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only; default: throw "Unknown GridwiseGemmPipelineVersion"; } @@ -137,6 +139,7 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() case PipelineVersion::V3: return ck_pipeline::v3; case PipelineVersion::V4: return ck_pipeline::v4; case PipelineVersion::V5: return ck_pipeline::v5; + case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile."; case PipelineVersion::WEIGHT_ONLY: throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version."; default: throw "Unknown block GEMM PipelineVersion"; diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp index b7df0e4d0e..12482f3206 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp @@ -91,6 +91,13 @@ struct TilePipelineType using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; }; +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6; +}; + template consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion() { @@ -103,6 +110,7 @@ consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion() case PipelineVersion::V3: return ck_tile_pipeline::COMPUTE_V3; case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4; case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5; + case PipelineVersion::V6: return ck_tile_pipeline::COMPUTE_V6; case PipelineVersion::WEIGHT_ONLY: throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version."; default: throw "Unknown block GEMM PipelineVersion"; diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/args.hpp similarity index 82% rename from experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp rename to experimental/builder/include/ck_tile/builder/testing/conv/args.hpp index 51edf41cba..eba6771964 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv/args.hpp @@ -7,26 +7,25 @@ #include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" #include "ck_tile/builder/testing/testing.hpp" -#include "ck_tile/builder/testing/testing_reflect.hpp" #include "ck_tile/builder/testing/filter_extent.hpp" -#include "ck_tile/builder/testing/tensor_buffer.hpp" -#include "ck_tile/host/convolution_parameter.hpp" -#include "ck_tile/builder/testing/tensor_initialization.hpp" #include "ck_tile/builder/testing/tensor_descriptor.hpp" -#include "ck_tile/builder/testing/validation.hpp" +#include "ck_tile/host/convolution_parameter.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" /// This file implements common functionality for invoking/testing grouped /// forward convolutions created through the CK Builder API. The main item -/// of it is the ConvArgs structure - which contains a complete description +/// of it is the Args structure - which contains a complete description /// of a convolution operation. /// /// It is not intended that this file contains implementation details for /// actually launching a convolution operation. As this can be done /// through different APIs depending on the kernel (CK, CK Tile, or a /// reference implementation), the code dealing with that is split out -/// into a separate header for each implementation. +/// into a separate header for each implementation. Nor does this file +/// deal with details for defining the data types (`Inputs` and `Outputs`) +/// for different conv directions, that is also split out into separate +/// headers to keep this one small. namespace ck_tile::builder::test { @@ -56,7 +55,7 @@ struct ConvTensorLengths /// /// @see Args template - requires ValidConvSignature && ConvDirectionIsForward + requires ValidConvSignature struct Args { constexpr static auto SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -204,53 +203,4 @@ struct Args } }; -/// @brief `Inputs` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see Inputs -template - requires ValidConvSignature && ConvDirectionIsForward -struct Inputs -{ - void* input; - void* weight; - - static void reflect(const Args& args, const auto& inspect) - { - inspect("input", args.make_input_descriptor(), &Inputs::input); - inspect("weight", args.make_weight_descriptor(), &Inputs::weight); - } -}; - -/// @brief `Outputs` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see Outputs -template - requires ValidConvSignature && ConvDirectionIsForward -struct Outputs -{ - void* output; - - static void reflect(const Args& args, const auto& inspect) - { - inspect("output", args.make_output_descriptor(), &Outputs::output); - } -}; - -/// @brief `init_inputs()` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see alloc_inputs() -template - requires ValidConvSignature && ConvDirectionIsForward -void init_inputs(const Args& args, Inputs inputs) -{ - init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f); - init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f); -} - } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight.hpp new file mode 100644 index 0000000000..ce5811c87a --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight.hpp @@ -0,0 +1,71 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/tensor_initialization.hpp" +#include "ck_tile/builder/testing/testing_reflect.hpp" +#include "ck_tile/builder/testing/conv/args.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" +#include "ck_tile/builder/testing/error.hpp" + +/// This file deals with the backward weight-specific details of running grouped +/// convolution backwards weight operations. It mainly defines the data +/// structures (`Input` and `Output`), initialization, and validation. Note +/// that for this operation specifically, many of the operations are +/// implemented automatically via testing_reflect.hpp. + +namespace ck_tile::builder::test { + +/// @brief `Inputs` specialization for backwards weight convolution. +/// +/// @tparam SIGNATURE Backwards weight convolution signature. +/// +/// @see Inputs +template + requires ValidConvSignature && ConvDirectionIsBackwardWeight +struct Inputs +{ + void* input; + void* output; + + // See testing_reflect.hpp + static void reflect(const Args& args, const auto& inspect) + { + inspect("input", args.make_input_descriptor(), &Inputs::input); + inspect("output", args.make_output_descriptor(), &Inputs::output); + } +}; + +/// @brief `Outputs` specialization for backwards weight convolution. +/// +/// @tparam SIGNATURE Backwards weight convolution signature. +/// +/// @see Outputs +template + requires ValidConvSignature && ConvDirectionIsBackwardWeight +struct Outputs +{ + void* weight; + + // See testing_reflect.hpp + static void reflect(const Args& args, const auto& inspect) + { + inspect("weight", args.make_weight_descriptor(), &Outputs::weight); + } +}; + +/// @brief `init_inputs()` specialization for backwards convolution. +/// +/// @tparam SIGNATURE Backwards weight convolution signature. +/// +/// @see init_inputs() +template + requires ValidConvSignature && ConvDirectionIsBackwardWeight +void init_inputs(const Args& args, Inputs inputs) +{ + init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f); + init_tensor_buffer_uniform_fp(inputs.output, args.make_output_descriptor(), -2.0f, 2.0f); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight_ck.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight_ck.hpp new file mode 100644 index 0000000000..0b1ffeb707 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight_ck.hpp @@ -0,0 +1,276 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/testing.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include +#include + +/// This file contains the implementation details for invoking/testing +/// bwd grouped convolution operations in old CK. The main item is the +/// `run()` function, which is the main implementation used to invoke +/// CK grouped forward convolution kernels. + +namespace ck_tile::builder::test { + +namespace detail { + +/// @brief Concept for checking whether a bwd weight convolution is invoked like old CK. +/// +/// This is the same as `::ck_tile::builder::test::CkConvBwdWeightInstance`, except +/// with some utility aliases. For that reason, its moved to this detail +/// namespace. +template , + typename Ops = factory::internal::ConvElementwiseOps> +concept CkConvBwdWeightInstance = requires(Conv& conv, + const Types::InDataType* p_a, + Types::WeiDataType* p_b, + const Types::OutDataType* p_e, + std::array lengths, + std::array strides, + std::array filter, + Ops::InElementwiseOp elementwise_a, + Ops::WeiElementwiseOp elementwise_b, + Ops::OutElementwiseOp elementwise_cde, + ck::index_t split_k) { + requires ValidConvSignature; + requires ConvDirectionIsBackwardWeight; + + { + conv.MakeArgument(p_a, + p_b, + p_e, + // A lengths/strides + lengths, + strides, + // B lengths/strides + lengths, + strides, + // E lengths/strides + lengths, + strides, + // strides/dilations/pads + filter, + filter, + filter, + filter, + // element-wise operations. + elementwise_a, + elementwise_b, + elementwise_cde, + split_k) + }; +}; + +/// @brief Concept for checking whether a bwd weight convolution is multiple-d and +/// invoked like old CK. +/// +/// This is the same as `::ck_tile::builder::test::CkConvBwdWeightMultipleDInstance`, except +/// with some utility aliases. For that reason, its moved to this detail +/// namespace. +template , + typename Ops = factory::internal::ConvElementwiseOps> +concept CkConvBwdWeightMultipleDInstance = requires(Conv& conv, + const Types::InDataType* p_a, + Types::WeiDataType* p_b, + const Types::OutDataType* p_e, + std::array lengths, + std::array strides, + std::array filter, + Ops::InElementwiseOp elementwise_a, + Ops::WeiElementwiseOp elementwise_b, + Ops::OutElementwiseOp elementwise_cde, + ck::index_t split_k) { + requires ValidConvSignature; + requires ConvDirectionIsBackwardWeight; + + { + conv.MakeArgument(p_a, + p_b, + p_e, + // TODO: Actually support multiple d + {}, + // A lengths/strides + lengths, + strides, + // B lengths/strides + lengths, + strides, + // E lengths/strides + lengths, + strides, + // TODO: Multiple D lengths/strides + {}, + {}, + // strides/dilations/pads + filter, + filter, + filter, + filter, + // element-wise operations. + elementwise_a, + elementwise_b, + elementwise_cde, + split_k) + }; +}; + +} // namespace detail + +/// @brief Concept for checking whether a bwd weight convolution is invoked like old CK. +/// +/// - SIGNATURE is the operation signature. +/// - Conv is a convolution instance created by the CK Builder API. +template +concept CkConvBwdWeightInstance = detail::CkConvBwdWeightInstance; + +/// @brief Concept for checking whether a bwd weight convolution is multiple-d and +/// invoked like old CK. +/// +/// - SIGNATURE is the operation signature. +/// - Conv is a convolution instance created by the CK Builder API. +template +concept CkConvBwdWeightMultipleDInstance = + detail::CkConvBwdWeightMultipleDInstance; + +/// @brief `run()` specialization for backward weight convolution and old CK. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// @returns RunResult about how the operation completed (or not). +/// +/// @see run() +template +[[nodiscard]] RunResult run(CkConvBwdWeightInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) +{ + using Types = factory::internal::ConvTensorDataTypes; + + constexpr auto spatial_dim = SIGNATURE.spatial_dim; + + const auto copy = [](const auto& src, auto& dst) { + std::copy(src.begin(), src.end(), dst.begin()); + }; + + const auto to_ck_lengths = [&](const auto& src) { + std::array result; + copy(src, result); + return result; + }; + + const auto to_ck_extent = [&](const auto& extent) { + std::array result; + copy(extent, result); + return result; + }; + + const auto param = args.to_ck_conv_param(); + + const auto input_desc = args.make_input_descriptor(); + const auto weight_desc = args.make_weight_descriptor(); + const auto output_desc = args.make_output_descriptor(); + + auto ck_args = conv.MakeArgument(static_cast(inputs.input), + static_cast(outputs.weight), + static_cast(inputs.output), + to_ck_lengths(input_desc.get_lengths()), + to_ck_lengths(input_desc.get_strides()), + to_ck_lengths(weight_desc.get_lengths()), + to_ck_lengths(weight_desc.get_strides()), + to_ck_lengths(output_desc.get_lengths()), + to_ck_lengths(output_desc.get_strides()), + to_ck_extent(param.conv_filter_strides_), + to_ck_extent(param.conv_filter_dilations_), + to_ck_extent(param.input_left_pads_), + to_ck_extent(param.input_right_pads_), + args.a_elementwise_op, + args.b_elementwise_op, + args.cde_elementwise_op, + args.k_batch); + + if(!conv.IsSupportedArgument(ck_args)) + return RunResult::not_supported("invalid ck arguments"); + + return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, {})); +} + +/// @brief `run()` specialization for backward weight convolution and old CK. +/// +/// This overload is specialized for Multiple-D. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// @returns RunResult about how the operation completed (or not). +/// +/// @see run() +template +[[nodiscard]] RunResult run(CkConvBwdWeightMultipleDInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) +{ + using Types = factory::internal::ConvTensorDataTypes; + + constexpr auto spatial_dim = SIGNATURE.spatial_dim; + + const auto copy = [](const auto& src, auto& dst) { + std::copy(src.begin(), src.end(), dst.begin()); + }; + + const auto to_ck_lengths = [&](const auto& src) { + std::array result; + copy(src, result); + return result; + }; + + const auto to_ck_extent = [&](const auto& extent) { + std::array result; + copy(extent, result); + return result; + }; + + const auto param = args.to_ck_conv_param(); + + const auto input_desc = args.make_input_descriptor(); + const auto weight_desc = args.make_weight_descriptor(); + const auto output_desc = args.make_output_descriptor(); + + auto ck_args = conv.MakeArgument(static_cast(inputs.input), + static_cast(outputs.weight), + static_cast(inputs.output), + {}, // TODO + to_ck_lengths(input_desc.get_lengths()), + to_ck_lengths(input_desc.get_strides()), + to_ck_lengths(weight_desc.get_lengths()), + to_ck_lengths(weight_desc.get_strides()), + to_ck_lengths(output_desc.get_lengths()), + to_ck_lengths(output_desc.get_strides()), + {}, // TODO + {}, // TODO + to_ck_extent(param.conv_filter_strides_), + to_ck_extent(param.conv_filter_dilations_), + to_ck_extent(param.input_left_pads_), + to_ck_extent(param.input_right_pads_), + args.a_elementwise_op, + args.b_elementwise_op, + args.cde_elementwise_op, + args.k_batch); + + if(!conv.IsSupportedArgument(ck_args)) + return RunResult::not_supported("invalid ck arguments"); + + return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, {})); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck_tile.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp similarity index 52% rename from experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck_tile.hpp rename to experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp index a8f6825524..133d7d69b7 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck_tile.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp @@ -3,9 +3,8 @@ #pragma once -#include "ck_tile/builder/testing/conv_fwd.hpp" +#include "ck_tile/builder/testing/testing.hpp" #include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/grouped_convolution.hpp" #include @@ -28,9 +27,39 @@ namespace detail { /// namespace. template concept CkTileConvInstance = requires(Conv&) { + requires ValidConvSignature; { Conv::BlockSize() }; }; +template +[[nodiscard]] RunResult run(CkTileConvInstance auto& conv, + const Args& args, + InDataType* input, + WeiDataType* weight, + OutDataType* output, + const ck_tile::stream_config s_conf) +{ + using Conv = std::remove_reference_t; + const auto param = args.to_ck_tile_conv_param(); + + ck_tile::GroupedConvHostArgs + host_args(param, input, weight, {}, output, args.k_batch); + + auto kargs = Conv::MakeKernelArgs(host_args); + + const dim3 grids = Conv::GridSize(kargs); + const dim3 blocks = Conv::BlockSize(); + + if(!Conv::IsSupportedArgument(kargs)) + return RunResult::not_supported("unsupported ck_tile arguments"); + + constexpr index_t minimum_occupancy = + Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2; + + return RunResult::from_runtime(ck_tile::launch_kernel( + s_conf, ck_tile::make_kernel(conv, grids, blocks, 0, kargs))); +} + } // namespace detail /// @brief Concept for checking whether a convolution is invoked like CK Tile. @@ -48,44 +77,45 @@ concept CkTileConvInstance = detail::CkTileConvInstance; /// @brief `run()` specialization for forward convolution and CK Tile. /// /// @tparam SIGNATURE Forward convolution signature. -/// @throws std::runtime_error if the arguments weren't actually valid for the -/// operation. This should be caught and reported by the testing framework. -/// @return std::tuple - whether the problem is supported and -/// kernel execution time (0.0f if s_conf time_kernel is false). +/// @returns RunResult about how the operation completed (or not). /// /// @see run() template - requires ValidConvSignature && ConvDirectionIsForward -std::tuple run(CkTileConvInstance auto& conv, + requires ConvDirectionIsForward +[[nodiscard]] RunResult run(CkTileConvInstance auto& conv, const Args& args, const Inputs& inputs, const Outputs& outputs, const ck_tile::stream_config s_conf = {}) { - using Conv = std::remove_reference_t; - const auto param = args.to_ck_tile_conv_param(); + return detail::run(conv, + args, + static_cast(inputs.input), + static_cast(inputs.weight), + static_cast(outputs.output), + s_conf); +} - ck_tile::GroupedConvFwdHostArgs<> host_args( - param, inputs.input, inputs.weight, {}, outputs.output, args.k_batch); - - auto kargs = Conv::MakeKernelArgs(host_args); - - const dim3 grids = Conv::GridSize(kargs); - const dim3 blocks = Conv::BlockSize(); - - if(!Conv::IsSupportedArgument(kargs)) - { - std::cout << "Not supported!"; - return std::make_tuple(false, 0.f); - } - - constexpr index_t minimum_occupancy = - Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2; - - return std::make_tuple( - true, - ck_tile::launch_kernel( - s_conf, ck_tile::make_kernel(conv, grids, blocks, 0, kargs))); +/// @brief `run()` specialization for backwards weight convolution and CK Tile. +/// +/// @tparam SIGNATURE Backwards weight convolution signature. +/// @returns RunResult about how the operation completed (or not). +/// +/// @see run() +template + requires ConvDirectionIsBackwardWeight +[[nodiscard]] RunResult run(CkTileConvInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs, + const ck_tile::stream_config s_conf = {}) +{ + return detail::run(conv, + args, + static_cast(inputs.input), + static_cast(outputs.weight), + static_cast(inputs.output), + s_conf); } } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/fwd.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/fwd.hpp new file mode 100644 index 0000000000..b81892c91e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv/fwd.hpp @@ -0,0 +1,69 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/tensor_initialization.hpp" +#include "ck_tile/builder/testing/testing_reflect.hpp" +#include "ck_tile/builder/testing/conv/args.hpp" + +/// This file deals with the forward-specific details of running grouped +/// convolution forward operations. It mainly defines the data structures +/// (`Input` and `Output`), initialization, and validation. Note that +/// for this operation specifically, many of the operations are implemented +/// automatically via testing_reflect.hpp. + +namespace ck_tile::builder::test { + +/// @brief `Inputs` specialization for forward convolution. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// +/// @see Inputs +template + requires ValidConvSignature && ConvDirectionIsForward +struct Inputs +{ + void* input; + void* weight; + + // See testing_reflect.hpp + static void reflect(const Args& args, const auto& inspect) + { + inspect("input", args.make_input_descriptor(), &Inputs::input); + inspect("weight", args.make_weight_descriptor(), &Inputs::weight); + } +}; + +/// @brief `Outputs` specialization for forward convolution. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// +/// @see Outputs +template + requires ValidConvSignature && ConvDirectionIsForward +struct Outputs +{ + void* output; + + // See testing_reflect.hpp + static void reflect(const Args& args, const auto& inspect) + { + inspect("output", args.make_output_descriptor(), &Outputs::output); + } +}; + +/// @brief `init_inputs()` specialization for forward convolution. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// +/// @see init_inputs() +template + requires ValidConvSignature && ConvDirectionIsForward +void init_inputs(const Args& args, Inputs inputs) +{ + init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f); + init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/fwd_ck.hpp similarity index 73% rename from experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp rename to experimental/builder/include/ck_tile/builder/testing/conv/fwd_ck.hpp index f911dca21c..5eca79508c 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv/fwd_ck.hpp @@ -3,14 +3,14 @@ #pragma once -#include "ck_tile/builder/testing/conv_fwd.hpp" -#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/builder/testing/testing.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/host/kernel_launch.hpp" #include #include /// This file contains the implementation details for invoking/testing -/// grouped convolution operations in old CK. The main item is the +/// fwd grouped convolution operations in old CK. The main item is the /// `run()` function, which is the main implementation used to invoke /// CK grouped forward convolution kernels. @@ -18,10 +18,9 @@ namespace ck_tile::builder::test { namespace detail { -/// @brief Concept for checking whether this is the reference convolution -/// implementation. +/// @brief Concept for checking whether a fwd convolution is invoked like old CK. /// -/// This is the same as `::ck_tile::builder::test::CkConvInstance`, except +/// This is the same as `::ck_tile::builder::test::CkConvFwdInstance`, except /// with some utility aliases. For that reason, its moved to this detail /// namespace. template > -concept CkConvInstance = requires(Conv& conv, - // TODO: This should be changed depending on IsMultiA etc. - // Currently that is not yet supported elsewhere anyway. - const void* p_a, - const void* p_b, - void* p_e, - std::array lengths, - std::array strides, - std::array filter, - Ops::InElementwiseOp elementwise_a, - Ops::WeiElementwiseOp elementwise_b, - Ops::OutElementwiseOp elementwise_cde) { +concept CkConvFwdInstance = requires(Conv& conv, + // TODO: This should be changed depending on IsMultiA etc. + // Currently that is not yet supported elsewhere anyway. + const void* p_a, + const void* p_b, + void* p_e, + std::array lengths, + std::array strides, + std::array filter, + Ops::InElementwiseOp elementwise_a, + Ops::WeiElementwiseOp elementwise_b, + Ops::OutElementwiseOp elementwise_cde) { + requires ValidConvSignature; + requires ConvDirectionIsForward; + { conv.MakeArgument(p_a, p_b, @@ -73,7 +75,7 @@ concept CkConvInstance = requires(Conv& conv, } // namespace detail -/// @brief Concept for checking whether a convolution is invoked like old CK. +/// @brief Concept for checking whether a fwd convolution is invoked like old CK. /// /// This concept is used to tell whether a convolution implementation is /// likely to be an "old CK" implementation - that is, whether we should @@ -83,20 +85,17 @@ concept CkConvInstance = requires(Conv& conv, /// - SIGNATURE is the operation signature. /// - Conv is a convolution instance created by the CK Builder API. template -concept CkConvInstance = detail::CkConvInstance; +concept CkConvFwdInstance = detail::CkConvFwdInstance; /// @brief `run()` specialization for forward convolution and old CK. /// /// @tparam SIGNATURE Forward convolution signature. -/// @throws std::runtime_error if the arguments weren't actually valid for the -/// operation. This should be caught and reported by the testing framework. -/// @return std::tuple - whether the problem is supported and -/// kernel execution time (0.0f if s_conf time_kernel is false). +/// @returns RunResult about how the operation completed (or not). /// /// @see run() template requires ValidConvSignature && ConvDirectionIsForward -std::tuple run(CkConvInstance auto& conv, +[[nodiscard]] RunResult run(CkConvFwdInstance auto& conv, const Args& args, const Inputs& inputs, const Outputs& outputs, @@ -126,6 +125,9 @@ std::tuple run(CkConvInstance auto& conv, const auto weight_desc = args.make_weight_descriptor(); const auto output_desc = args.make_output_descriptor(); + if(args.k_batch != 1) + return RunResult::not_supported("ck fwd does not support k_batch != 1"); + auto ck_args = conv.MakeArgument(inputs.input, inputs.weight, {}, @@ -147,11 +149,9 @@ std::tuple run(CkConvInstance auto& conv, args.cde_elementwise_op); if(!conv.IsSupportedArgument(ck_args)) - { - std::cout << "invalid argument" << std::endl; - } + return RunResult::not_supported("unsupported ck arguments"); - return std::make_tuple(true, conv.MakeInvoker().Run(ck_args, s_conf)); + return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, s_conf)); } } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/reference.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/reference.hpp new file mode 100644 index 0000000000..169d0741ff --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv/reference.hpp @@ -0,0 +1,137 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/testing.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include +#include + +/// This file contains the implementation details for invoking/testing +/// grouped convolution operations using the reference implementation. +/// The main item is the `run()` function, which is the primary way to +/// invoke the reference execution mechanism. +/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`, +/// but its made specific to the reference implementation, which is +/// invoked in a slightly different way. + +namespace ck_tile::builder::test { + +namespace detail { + +/// @brief Concept for checking whether this is the reference convolution +/// implementation. +/// +/// This concept is used to tell whether a convolution implementation is +/// likely to be the reference implementation - that is, whether we should +/// invoke it like the reference kernel. This is mainly used with `run()` to +/// differentiate which implementation that should be invoked. +/// +/// - SIGNATURE is the operation signature. +/// - Conv is a convolution instance created by the CK Builder API. +/// - InDataType, WeiDataType, OutDataType are the types of the respective tensors. +template +concept RefConvInstance = requires(Conv& conv, + InDataType* input, + WeiDataType* weight, + OutDataType* output, + ck::utils::conv::ConvParam param) { + requires ValidConvSignature; + { conv.Run(input, weight, output, param) }; +}; + +/// @brief Generic `run` implementation for forward/backwards reference kernels. +/// +/// @tparam SIGNATURE The signature of the operation to perform. +/// +/// @return std::tuple - whether the problem is supported and +/// kernel execution time (0.0f for reference). +/// @see run() +template +[[nodiscard]] RunResult +run(RefConvInstance auto& conv, + const Args& args, + InDataType* input, + WeiDataType* weight, + OutDataType* output) +{ + // We don't want to compute the output dims manually, just get + // them via the existing infrastructure + const auto param = args.to_ck_conv_param(); + + // TODO: The reference convolution is currently missing a few features. + // Just throw for now, but regard these as TODO items that should be resolved + // eventually. + + if(!args.make_input_descriptor().is_packed()) + return RunResult::not_supported("TODO: Support non-packed input tensor in reference conv"); + + if(!args.make_weight_descriptor().is_packed()) + return RunResult::not_supported("TODO: Support non-packed weight tensor in reference conv"); + + if(!args.make_output_descriptor().is_packed()) + return RunResult::not_supported("TODO: Support non-packed output tensor in reference conv"); + + conv.Run(input, weight, output, param); + return RunResult::from_runtime(0); // ref conv does not return a meaningful runtime. +} + +} // namespace detail + +/// @brief Concept for checking whether this is the reference convolution +/// forward implementation. +template +concept RefConvFwdInstance = + detail::RefConvInstance && + ConvDirectionIsForward; + +/// @brief `run()` specialization for forward convolution and the reference +/// forward implementation. +/// +/// @tparam SIGNATURE The signature of the operation to perform. Must be forwards. +/// @returns RunResult about how the operation completed (or not). +/// +/// @see run() +template + requires ValidConvSignature && + // TODO: Maybe we can unify this implementation for bwd/weight too? + // for now, just concern outselves with reference and see when the + // rest of the bwd/weight plumbing is there. + ConvDirectionIsForward +[[nodiscard]] RunResult run(RefConvFwdInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) +{ + return detail::run(conv, args, inputs.input, inputs.weight, outputs.output); +} + +/// @brief Concept for checking whether this is the reference convolution +/// backward weight implementation. +template +concept RefConvBwdWeightInstance = + detail::RefConvInstance && + ConvDirectionIsBackwardWeight; + +/// @brief `run()` specialization for forward convolution and the reference +/// backward weight implementation. +/// +/// @tparam SIGNATURE The signature of the operation to perform. Must be backwards weight. +/// @returns RunResult about how the operation completed (or not). +/// +/// @see run() +template +[[nodiscard]] RunResult run(RefConvBwdWeightInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) +{ + return detail::run(conv, args, inputs.input, outputs.weight, inputs.output); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp deleted file mode 100644 index ff276f7c9c..0000000000 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/builder/testing/conv_fwd.hpp" -#include -#include - -/// This file contains the implementation details for invoking/testing -/// grouped convolution operations using the reference implementation. -/// The main item is the `run()` function, which is the primary way to -/// invoke the reference execution mechanism. -/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`, -/// but its made specific to the reference implementation, which is -/// invoked in a slightly different way. - -namespace ck_tile::builder::test { - -/// @brief Concept for checking whether this is the reference convolution -/// implementation. -/// -/// This concept is used to tell whether a convolution implementation is -/// likely to be the reference implementation - that is, whether we should -/// invoke it like the reference kernel. This is mainly used with `run()` to -/// differentiate which implementation that should be invoked. -/// -/// - SIGNATURE is the operation signature. -/// - Conv is a convolution instance created by the CK Builder API. -template -concept RefConvInstance = requires(Conv& conv, - const void* input, - const void* weight, - void* output, - ck::utils::conv::ConvParam param) { - { conv.Run(input, weight, output, param) }; -}; - -/// @brief `run()` specialization for forward convolution and the reference -/// implementation. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// @throws std::runtime_error if the arguments weren't actually valid for the -/// operation. This should be caught and reported by the testing framework. -/// -/// @return std::tuple - whether the problem is supported and -/// kernel execution time (0.0f for reference). -/// @see run() -template - requires ValidConvSignature && - // TODO: Maybe we can unify this implementation for bwd/weight too? - // for now, just concern outselves with reference and see when the - // rest of the bwd/weight plumbing is there. - ConvDirectionIsForward -std::tuple run(RefConvInstance auto& conv, - const Args& args, - const Inputs& inputs, - const Outputs& outputs) -{ - // We don't want to compute the output dims manually, just get - // them via the existing infrastructure - const auto param = args.to_ck_conv_param(); - - // TODO: The reference convolution is currently missing a few features. - // Just throw for now, but regard these as TODO items that should be resolved - // eventually. - - if(!args.make_input_descriptor().is_packed()) - { - std::cout << "TODO: Support non-packed input tensor in reference conv" << std::endl; - return std::make_tuple(false, 0.0f); - } - if(!args.make_weight_descriptor().is_packed()) - { - std::cout << "TODO: Support non-packed weight tensor in reference conv" << std::endl; - return std::make_tuple(false, 0.0f); - } - if(!args.make_output_descriptor().is_packed()) - { - std::cout << "TODO: Support non-packed output tensor in reference conv" << std::endl; - return std::make_tuple(false, 0.0f); - } - - conv.Run(inputs.input, inputs.weight, outputs.output, param); - return std::make_tuple(true, 0.0f); -} - -} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp index 2976e6c14b..35fc1f4ee8 100644 --- a/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp @@ -12,6 +12,7 @@ #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" #include "ck_tile/builder/testing/type_traits.hpp" +#include "ck_tile/builder/testing/tensor_descriptor.hpp" #include "ck_tile/host/host_tensor.hpp" #include "ck/utility/data_type.hpp" diff --git a/experimental/builder/include/ck_tile/builder/testing/testing.hpp b/experimental/builder/include/ck_tile/builder/testing/testing.hpp index e61d7c4da5..307871b47a 100644 --- a/experimental/builder/include/ck_tile/builder/testing/testing.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/testing.hpp @@ -3,7 +3,11 @@ #pragma once +#include #include +#include +#include +#include #include "ck_tile/builder/testing/tensor_descriptor.hpp" #include "ck_tile/builder/testing/tensor_buffer.hpp" @@ -288,6 +292,57 @@ ValidationReport validate(const Args& args, Outputs actual, Outputs expected) = delete; +/// @brief This structure represents the result of a run operation. +/// +/// The structure contains multiple fields with information about +/// how the operation completed (or not). See those for more info. +struct RunResult +{ + /// If this value is not set to `std::nullopt`, there was a problem + /// while running the algorithm. In this case, the outputs are not + /// valid (though may be partially or completely overwritten), and + /// the optional contains a short debug message that indicates the + /// problem. + std::optional error = std::nullopt; + + /// The runtime of the kernel in milliseconds, if measured. Whether the + /// runtime is measured at all depends on the stream configuration + /// passed to run(). 0 if not measured or if there was an error. This + /// value is averaged over the total amount of runs actually done. Again, + /// this is usually configured via the stream config. + float runtime = 0.f; + + /// @brief Utility function for constructing a RunResult from an unsupported operation. + /// + /// @param msg A short debug message that will be included in the result. + constexpr static RunResult not_supported(std::string_view msg) + { + return RunResult{.error = std::string(msg)}; + } + + /// @brief Utility function for constructing a RunResult from an average runtime, + /// indicating a successful operation. + /// + /// @param runtime The runtime of the kernel in milliseconds. + constexpr static RunResult from_runtime(const float runtime) + { + return RunResult{.runtime = runtime}; + } + + /// @brief Returns whether this algorithm executed successfully. + /// + /// In this case there should be no message in `error`. + bool is_supported() const { return !this->error.has_value(); } +}; + +inline std::ostream& operator<<(std::ostream& os, const RunResult& result) +{ + if(result.error.has_value()) + return os << "invalid run (" << result.error.value() << ")"; + else + return os << "successful run (" << result.runtime << " ms)"; +} + /// @brief Invoke a device operation created by CK Builder. /// /// This is the main function used to invoke a particular device operation @@ -318,13 +373,14 @@ ValidationReport validate(const Args& args, /// @param outputs The output tensor data. The contents will be overwritten by /// this function. /// @param s_conf Stream config used to launch kernel. -/// @return std::tuple - whether the problem is supported and -/// kernel execution time (0.0f if s_conf time_kernel is false). +/// @returns RunResult about how the operation completed (or not). /// /// @note This function is explicitly deleted to generate compile errors /// for missing implementations. +/// +/// @see RunResult template -std::tuple run(Operation& operation, +[[nodiscard]] RunResult run(Operation& operation, const Args& args, const Inputs& inputs, const Outputs& outputs, diff --git a/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp b/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp index 81d5b7a6f5..076b5e9751 100644 --- a/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp @@ -5,6 +5,8 @@ #include +#include "ck_tile/builder/testing/testing.hpp" + /// testing.hpp requires developers of a type of SIGNATURE to implement /// quite a lot of functionality for each SIGNATURE. For example, next /// to `Args`, `Inputs`, `Outputs`, `run`, they also have to define diff --git a/experimental/builder/include/ck_tile/builder/testing/validation.hpp b/experimental/builder/include/ck_tile/builder/testing/validation.hpp index 158f271e21..8410a71b15 100644 --- a/experimental/builder/include/ck_tile/builder/testing/validation.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/validation.hpp @@ -51,6 +51,9 @@ struct ValidationReport /// The number of elements which were bitwise 0. uint64_t zero_elements; + // Max error. + double max_error; + /// @brief Check whether both the output and reference tensor were both all zeros. /// /// If both tensors are all zero, it indicates either an incorrect testing setup @@ -133,11 +136,12 @@ bool ValidationReport::check(std::string_view tensor_name, // Initial pass: count errors // Allocate and reset counter - auto d_counters = alloc_buffer(sizeof(uint64_t) * 2); - check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2)); + auto d_counters = alloc_buffer(sizeof(uint64_t) * 3); + check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 3)); auto d_error_count = &reinterpret_cast(d_counters.get())[0]; auto d_zero_count = &reinterpret_cast(d_counters.get())[1]; + auto d_max_error = &reinterpret_cast(d_counters.get())[2]; tensor_foreach(descriptor.get_lengths(), [=](auto index) { using CKType = typename factory::internal::DataTypeToCK
::type; @@ -157,6 +161,7 @@ bool ValidationReport::check(std::string_view tensor_name, const auto r = static_cast(type_convert(b)); const auto err = std::abs(o - r); + atomicMax(d_max_error, err); if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { // We expect the number of errors to be very low, so just use an atomic @@ -188,6 +193,8 @@ bool ValidationReport::check(std::string_view tensor_name, check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost)); uint64_t zero_count = 0; check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost)); + double max_error = 0; + check_hip(hipMemcpy(&max_error, d_max_error, sizeof(double), hipMemcpyDeviceToHost)); // TODO: Gather detailed coordinates. @@ -196,6 +203,7 @@ bool ValidationReport::check(std::string_view tensor_name, .wrong_elements = error_count, .total_elements = descriptor.get_element_size(), .zero_elements = zero_count, + .max_error = max_error, }); return reports_.back().is_ok(); diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index c4cca05e52..dad123bae5 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -157,6 +157,7 @@ enum class PipelineVersion V3, V4, V5, + V6, WEIGHT_ONLY }; @@ -328,6 +329,7 @@ inline std::string_view to_string(PipelineVersion ver) case V3: return "V3"; case V4: return "V4"; case V5: return "V5"; + case V6: return "V6"; case WEIGHT_ONLY: return "WEIGHT_ONLY"; default: return "Unknown"; } diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 9890563859..73a682f10c 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -168,7 +168,7 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/ck/test_ckb_conv_fwd_3d_fp16.cpp conv/ck/test_ckb_conv_fwd_3d_fp32.cpp conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp - ) +) target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility) set(BWD_WEIGHT_TESTS diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp index 4ad97209e5..a3f4a988ef 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp @@ -1,23 +1,30 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +#include "ck_tile/builder/testing/conv/bwd_weight.hpp" +#include "ck_tile/builder/testing/conv/bwd_weight_ck.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" +#include "ck_tile/host/device_prop.hpp" #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" #include "utils/conv_algorithm_type_utils.hpp" -#include "ck_tile/host/device_prop.hpp" +#include "testing_utils.hpp" namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; + using enum ck_tile::builder::TensorLayout; +using ck_tile::test::MatchesReference; +using ck_tile::test::SuccessfulRun; constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1, .direction = ckb::ConvDirection::BACKWARD_WEIGHT, .data_type = ckb::DataType::BF16, .accumulation_data_type = ckb::DataType::FP32, - .input = {.config = {.layout = NGCW}}, + .input = {.config = {.layout = GNWC}}, .weight = {.config = {.layout = GKXC}}, - .output = {.config = {.layout = NGKW}}}; + .output = {.config = {.layout = GNWK}}}; constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{} @@ -30,14 +37,58 @@ constexpr auto ALGORITHM = using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; +using Reference = ckb::ConvBuilder::Instance; + TEST(BwdWeight_1DBf16_CShuffle_V3, Create) { const auto expected_transfer_parameters = to_string(ALGORITHM); cku::run_test({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3", expected_transfer_parameters, "Filter1x1Stride1Pad0", - "NGCW,GKXC,NGKW", + "GNWC,GKXC,GNWK", "PassThrough,PassThrough,PassThrough", "Intrawave", "v2"}); } + +TEST(BwdWeight_1DBf16_CShuffle_V3, Execution) +{ + if(!ck_tile::get_device_name().starts_with("gfx9")) + { + // Note: XDL kernel + GTEST_SKIP() << "unsupported architecture"; + } + + ckt::Args args = { + .lengths = + { + .batch_size = 16, + .groups = 1, + .input_channels = 32, + .output_channels = 48, + .image = {.width = 64}, + .filter = {.width = 1}, + }, + .filter_strides = {.width = 1}, + .filter_dilation = {.width = 1}, + .input_left_pad = {.width = 0}, + .input_right_pad = {.width = 0}, + .a_elementwise_op = {}, + .b_elementwise_op = {}, + .cde_elementwise_op = {}, + }; + + auto inputs = ckt::alloc_inputs(args); + auto outputs = ckt::alloc_outputs(args); + auto reference = ckt::alloc_outputs(args); + + ckt::init_inputs(args, inputs.get()); + + auto conv = Instance{}; + EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun()); + + auto ref_conv = Reference{}; + EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun()); + + EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index 3e5e39191e..51bc45c29b 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -4,8 +4,9 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" #include "utils/conv_algorithm_type_utils.hpp" -#include "ck_tile/builder/testing/conv_fwd_ck.hpp" -#include "ck_tile/builder/testing/conv_fwd_reference.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" +#include "ck_tile/builder/testing/conv/fwd_ck.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" #include "ck_tile/host/device_prop.hpp" #include "testing_utils.hpp" @@ -14,6 +15,7 @@ namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; using ck_tile::test::MatchesReference; +using ck_tile::test::SuccessfulRun; constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2, @@ -50,10 +52,11 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, Create) "MNKPadding"}); } -TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd) +TEST(Fwd2DFp16_CShufV3_GNHWC, Execution) { if(!ck_tile::get_device_name().starts_with("gfx9")) { + // Note: XDL kernel GTEST_SKIP() << "unsupported architecture"; } @@ -91,10 +94,10 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd) ckt::init_inputs(args, inputs.get()); auto conv = Instance{}; - ckt::run(conv, args, inputs.get(), outputs.get()); + EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun()); auto ref_conv = Reference{}; - ckt::run(ref_conv, args, inputs.get(), reference.get()); + EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun()); EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); } diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp index 292d852b91..60dc45545f 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp @@ -1,35 +1,47 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +#include "ck_tile/builder/testing/conv/bwd_weight.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" +#include "ck_tile/host/device_prop.hpp" #include "utils/ckb_conv_tile_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "testing_utils.hpp" -namespace { +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; -using namespace ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; +using ck_tile::test::MatchesReference; +using ck_tile::test::SuccessfulRun; -TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +constexpr auto SIGNATURE = cku::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NHWGC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NHWGK}}}; + +constexpr auto ALGORITHM = + cku::ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(ckb::TileConvSpecialization::DEFAULT) + .with_tile_thread_block(cku::TileThreadBlock_64x64x64) + .with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(cku::TileTransfer_4x4x4) + .with_tile_optimizations(ckt::TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +using Reference = ckb::ConvBuilder::Instance; + +TEST(BwdWeight_2D_FP16_NHWGC, Create) { - constexpr ConvSignature BwdWeightConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::BACKWARD_WEIGHT, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; - - constexpr auto BwdWeightConvAlgorithm = - ConvAlgorithm_Tile_GroupedConvolutionKernel{} - .with_tile_specializations(TileConvSpecialization::DEFAULT) - .with_tile_thread_block(TileThreadBlock_64x64x64) - .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) - .with_tile_transfer(TileTransfer_4x4x4) - .with_tile_optimizations(TileOptimizations{ - .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); - - using Builder = ConvBuilder; - run_ck_tile_test({ + cku::run_ck_tile_test({ "grouped_convolution_backward_weight", "fp16", "NHWGC_GKYXC_NHWGK", @@ -49,4 +61,38 @@ TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_ }); } -} // namespace +TEST(BwdWeight_2D_FP16_NHWGC, Execution) +{ + ckt::Args args = { + .lengths = + { + .batch_size = 2, + .groups = 4, + .input_channels = 32, + .output_channels = 48, + .image = {.width = 32, .height = 56}, + .filter = {.width = 3, .height = 3}, + }, + .filter_strides = {.width = 1, .height = 1}, + .filter_dilation = {.width = 1, .height = 1}, + .input_left_pad = {.width = 0, .height = 0}, + .input_right_pad = {.width = 0, .height = 0}, + .a_elementwise_op = {}, + .b_elementwise_op = {}, + .cde_elementwise_op = {}, + }; + + auto inputs = ckt::alloc_inputs(args); + auto outputs = ckt::alloc_outputs(args); + auto reference = ckt::alloc_outputs(args); + + ckt::init_inputs(args, inputs.get()); + + auto conv = Instance{}; + EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun()); + + auto ref_conv = Reference{}; + EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun()); + + EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); +} diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp index 128744dcc6..650c217b71 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp @@ -4,8 +4,8 @@ #include "utils/ckb_conv_tile_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" #include "utils/conv_algorithm_type_utils.hpp" -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" -#include "ck_tile/builder/testing/conv_fwd_reference.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" #include "ck_tile/host/device_prop.hpp" #include "testing_utils.hpp" @@ -13,6 +13,9 @@ namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; +using ck_tile::test::MatchesReference; +using ck_tile::test::SuccessfulRun; + constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2, .direction = ckb::ConvDirection::FORWARD, @@ -75,10 +78,10 @@ TEST(Fwd2DFp16_CShufV3_NHWGC, EndToEnd) ckt::init_inputs(args, inputs.get()); auto conv = Instance{}; - ckt::run(conv, args, inputs.get(), outputs.get()); + EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun()); auto ref_conv = Reference{}; - ckt::run(ref_conv, args, inputs.get(), reference.get()); + EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun()); - EXPECT_THAT(outputs.get(), ck_tile::test::MatchesReference(args, reference.get())); + EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); } diff --git a/experimental/builder/test/test_testing_utils.cpp b/experimental/builder/test/test_testing_utils.cpp index 43bbbd69eb..100122eef3 100644 --- a/experimental/builder/test/test_testing_utils.cpp +++ b/experimental/builder/test/test_testing_utils.cpp @@ -5,11 +5,14 @@ #include "testing_utils.hpp" +namespace ckt = ck_tile::builder::test; + using ck_tile::test::HipError; using ck_tile::test::HipSuccess; using ck_tile::test::InstanceMatcher; using ck_tile::test::InstanceSet; using ck_tile::test::StringEqWithDiff; +using ck_tile::test::SuccessfulRun; TEST(InstanceSet, FromFactory) { @@ -107,3 +110,17 @@ TEST(HipStatusMatcher, Basic) EXPECT_THAT(hipSuccess, Not(HipError(hipErrorInvalidValue))); EXPECT_THAT(hipErrorOutOfMemory, Not(HipError(hipErrorInvalidValue))); } + +TEST(RunResultMatcher, Basic) +{ + EXPECT_THAT(ckt::RunResult::from_runtime(0), SuccessfulRun()); + EXPECT_THAT(ckt::RunResult::not_supported("test error"), Not(SuccessfulRun())); +} + +TEST(RunResultMatcher, ExplainMatchResult) +{ + testing::StringMatchResultListener listener; + EXPECT_TRUE(!ExplainMatchResult( + SuccessfulRun(), ckt::RunResult::not_supported("test error"), &listener)); + EXPECT_THAT(listener.str(), StringEqWithDiff("run failed: test error")); +} diff --git a/experimental/builder/test/testing_utils.cpp b/experimental/builder/test/testing_utils.cpp index b60c35333e..e9677e5940 100644 --- a/experimental/builder/test/testing_utils.cpp +++ b/experimental/builder/test/testing_utils.cpp @@ -339,4 +339,22 @@ void HipStatusMatcher::DescribeNegationTo(std::ostream* os) const return ::testing::MakeMatcher(new HipStatusMatcher(error)); } +bool RunResultMatcher::MatchAndExplain(builder::test::RunResult actual, + ::testing::MatchResultListener* listener) const +{ + if(actual.error.has_value() && listener) + *listener << "run failed: " << actual.error.value(); + + return actual.is_supported(); +} + +void RunResultMatcher::DescribeTo(std::ostream* os) const { *os << "successful run"; } + +void RunResultMatcher::DescribeNegationTo(std::ostream* os) const { *os << "unsuccessful run"; } + +::testing::Matcher SuccessfulRun() +{ + return ::testing::MakeMatcher(new RunResultMatcher()); +} + } // namespace ck_tile::test diff --git a/experimental/builder/test/testing_utils.hpp b/experimental/builder/test/testing_utils.hpp index b84d53b6df..55de133a2a 100644 --- a/experimental/builder/test/testing_utils.hpp +++ b/experimental/builder/test/testing_utils.hpp @@ -161,6 +161,23 @@ struct HipStatusMatcher : public ::testing::MatcherInterface /// @param error The error to expect. ::testing::Matcher HipError(hipError_t error); +/// @brief RunResult matcher +/// +/// `ckt::run` returns a RunResult which indicates whether there was any +/// problem while running the algorithm. This matcher is used to match those +/// values. +struct RunResultMatcher : public ::testing::MatcherInterface +{ + bool MatchAndExplain(builder::test::RunResult actual, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + void DescribeNegationTo(std::ostream* os) const override; +}; + +/// @brief Construct a Google Test matcher that checks that a ckt::run result +/// was successful. +::testing::Matcher SuccessfulRun(); + template struct ReferenceOutputMatcher : public ::testing::MatcherInterface> @@ -180,6 +197,21 @@ struct ReferenceOutputMatcher if(listener->IsInterested() && !errors.empty()) { *listener << errors.size() << " tensors failed to validate"; + + for(const auto& e : errors) + { + *listener << "\n - " << e.tensor_name << ": "; + + if(e.is_all_zero()) + *listener << "all elements in actual and expected tensors are zero"; + else + { + // Round to 2 digits + const float percentage = e.wrong_elements * 10000 / e.total_elements / 100.f; + *listener << e.wrong_elements << "/" << e.total_elements + << " incorrect elements (~" << percentage << "%)"; + } + } } return errors.empty(); diff --git a/experimental/builder/test/unit_conv_fwd_testing.cpp b/experimental/builder/test/unit_conv_fwd_testing.cpp index be95a29a2d..9fc07568b4 100644 --- a/experimental/builder/test/unit_conv_fwd_testing.cpp +++ b/experimental/builder/test/unit_conv_fwd_testing.cpp @@ -3,7 +3,7 @@ #include "impl/conv_signature_types.hpp" #include "testing_utils.hpp" -#include "ck_tile/builder/testing/conv_fwd.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" #include "ck_tile/builder/testing/tensor_foreach.hpp" #include #include diff --git a/experimental/builder/test/unit_validation.cpp b/experimental/builder/test/unit_validation.cpp index a83d034ac2..0dad8593fb 100644 --- a/experimental/builder/test/unit_validation.cpp +++ b/experimental/builder/test/unit_validation.cpp @@ -296,5 +296,8 @@ TEST(MatchesReference, Incorrect) testing::StringMatchResultListener listener; EXPECT_TRUE(!ExplainMatchResult(MatchesReference(args, expected), actual, &listener)); - EXPECT_THAT(listener.str(), StringEqWithDiff("1 tensors failed to validate")); + EXPECT_THAT(listener.str(), + StringEqWithDiff( // + "1 tensors failed to validate\n" + " - a: 625/625 incorrect elements (~100%)")); } diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf index 9222a0858f..7cd2a3d85e 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf index 9222a0858f..7cd2a3d85e 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf index b9704c8100..e7ea32680d 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf index 9222a0858f..7cd2a3d85e 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf index 9222a0858f..7cd2a3d85e 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf index b9704c8100..e7ea32680d 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc b/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc index 4b4c144428..ae451caec0 100644 --- a/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc +++ b/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc @@ -1,5 +1,6 @@ #include "../../builder/test/utils/ckb_conv_tile_test_configs.hpp" -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; diff --git a/experimental/grouped_convolution_tile_instances/instances/instance_run.inc b/experimental/grouped_convolution_tile_instances/instances/instance_run.inc index 6b8024fa93..016ef3e653 100644 --- a/experimental/grouped_convolution_tile_instances/instances/instance_run.inc +++ b/experimental/grouped_convolution_tile_instances/instances/instance_run.inc @@ -2,8 +2,6 @@ using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; -auto conv = Instance{}; -bool is_supported; -float avg_time; -std::tie(is_supported, avg_time) = ckt::run(conv, args, inputs, outputs, s_conf); -return std::make_tuple(is_supported, avg_time, conv.GetInstanceString()); +auto conv = Instance{}; +ckt::RunResult result = ckt::run(conv, args, inputs, outputs, s_conf); +return std::make_tuple(result.is_supported(), result.runtime, conv.GetInstanceString()); diff --git a/include/ck/BUILD_TIME_OPTIMIZATION.md b/include/ck/BUILD_TIME_OPTIMIZATION.md new file mode 100644 index 0000000000..94b292b878 --- /dev/null +++ b/include/ck/BUILD_TIME_OPTIMIZATION.md @@ -0,0 +1,225 @@ +# Build Time Optimization + +Tracking issue: [#3575](https://github.com/ROCm/composable_kernel/issues/3575) + +This document describes techniques for reducing C++ template instantiation overhead in the Composable Kernel codebase. + +## Why Build Time Matters + +Composable Kernel relies heavily on C++ template metaprogramming to achieve GPU kernels with no runtime abstraction penalty. However, deep template instantiation can significantly impact build times. A single translation unit may trigger hundreds of thousands of template instantiations, with each instantiation adding to compile time. + +## Key Types + +This codebase uses compile-time types to enable zero-overhead abstractions: + +- `Number` - compile-time integer, enables static dispatch and compile-time arithmetic +- `Sequence` - compile-time integer sequence, used for dimension ordering and index manipulation +- `Tuple` - heterogeneous container holding different types, used for tensor descriptors and transforms + +These types allow the compiler to fully unroll loops, eliminate branches, and inline all operations - producing GPU kernels with no runtime abstraction cost. + +## Optimization Techniques + +### 1. Replace Recursive Templates with Pack Expansion + +Recursive template patterns create O(N) instantiation depth - the compiler must instantiate each level before proceeding to the next: + +``` +sequence_gen_impl<5, F> + → sequence_gen_impl<4, F> + → sequence_gen_impl<3, F> + → ... +``` + +Using `__make_integer_seq` (Clang/MSVC) combined with pack expansion reduces this to constant depth - the compiler generates the entire sequence in one step internally, without recursive template instantiation. + +**Before** (O(N) recursive instantiation): + +```cpp +template +struct sequence_gen_impl +{ + using type = typename sequence_gen_impl{}), Is...>::type; +}; + +template +struct sequence_gen_impl<0, F, Is...> +{ + using type = Sequence; +}; +``` + +**After** (constant depth using compiler intrinsic + pack expansion): + +```cpp +namespace detail { + +template +struct sequence_gen_helper +{ + // Apply functor F to all indices via pack expansion + // F{}(Number<0>{}), F{}(Number<1>{}), ..., F{}(Number{}) + template + using apply = Sequence{})...>; +}; + +} // namespace detail + +template +struct sequence_gen +{ + // __make_integer_seq produces + // sequence_gen_helper with constant depth + using type = + typename __make_integer_seq::template apply; +}; +``` + +Note: This document assumes C++17 or later. While `std::make_integer_sequence` (introduced in C++14) is the standard library facility for generating integer sequences, it only produces `std::integer_sequence`. We use `__make_integer_seq` directly because it accepts any template as its first argument, enabling this pattern where the helper class receives the index pack directly. + +### 2. Replace Lambdas with Named Functors + +Each lambda expression creates a unique closure type, causing separate template instantiations at every call site. Named functors share a single type across all uses. + +**Before** (lambda creates unique instantiations at each call site): + +```cpp +// The lambda inside transform_tensor_descriptor: +generate_tuple([](auto i) { return Sequence{}; }, Number{}); +``` + +**After** (named functor shares instantiations): + +```cpp +// Define functor once +struct generate_identity_sequence +{ + template + __host__ __device__ constexpr auto operator()(Number) const + { + return Sequence{}; + } +}; + +// Use everywhere - shares instantiations +generate_tuple(generate_identity_sequence{}, Number{}); +``` + +This reduced `transform_tensor_descriptor` instantiations from 388 to 32 (92% reduction). + +**Example: container_concat** + +```cpp +// Before: lambda creates unique type per call site +// (unpack2 applies a functor to all elements from both tuples) +template +__host__ __device__ constexpr auto container_concat(const Tuple& tx, const Tuple& ty) +{ + return unpack2([](auto&&... zs) { return make_tuple(forward(zs)...); }, tx, ty); +} + +// After: named functor shares instantiations +struct make_tuple_functor +{ + template + __host__ __device__ constexpr auto operator()(Ts&&... xs) const + { + return make_tuple(forward(xs)...); + } +}; + +template +__host__ __device__ constexpr auto container_concat(const Tuple& tx, const Tuple& ty) +{ + return unpack2(make_tuple_functor{}, tx, ty); +} +``` + +This reduced `container_concat` instantiations from 186 to 93 (50% reduction). + +**Example: make_uniform_tuple** + +For patterns that create tuples with repeated values: + +```cpp +// Before: unique lambda type at each call site +generate_tuple([](auto) { return some_value; }, Number{}); + +// After: dedicated helper function +template +__host__ __device__ constexpr auto make_uniform_tuple(T&& value) +{ + return detail::make_uniform_tuple_impl(static_cast(value), make_index_sequence{}); +} + +// Usage +make_uniform_tuple(some_value); +``` + +### 3. Use Constexpr Loops Instead of Template Recursion + +Template recursion creates N template instantiations for N iterations. A constexpr loop executes at compile time but only requires a single template instantiation. While both are O(N) in complexity, constexpr loops are significantly faster because they avoid the overhead of template instantiation. + +**Before** (O(N) template instantiations): + +```cpp +// Simplified example - actual CK code used more complex recursive patterns +template +struct find_source_index_impl +{ + static constexpr index_t value = + (Seq::template At() == Target) ? Pos : find_source_index_impl::value; +}; + +template +struct find_source_index_impl +{ + static constexpr index_t value = -1; // not found +}; +``` + +**After** (single instantiation with constexpr loop): + +```cpp +template +__host__ __device__ constexpr index_t find_source_index(Sequence) +{ + // Simplified example - actual implementation handles empty sequences + constexpr index_t values[] = {Is...}; + for(index_t i = 0; i < sizeof...(Is); ++i) + if(values[i] == Target) return i; + return -1; // not found +} +``` + +This reduced `sequence_map_inverse` instantiations from 45 to 10 (78% reduction) and wall-clock time by 95%. + +### 4. Use Fold Expressions for Accumulation + +Fold expressions (C++17) can replace recursive template patterns for accumulation operations. + +**Before** (uses helper utilities that hide template recursion: `generate_tuple` recursively constructs a tuple of N elements, and `container_reduce` recursively reduces that tuple): + +```cpp +const auto element_space_size = container_reduce( + generate_tuple([&](auto i) { + return (lengths[i] - Number<1>{}) * strides[i]; + }, Number{}), + math::plus{}, Number<1>{}); +``` + +**After** (single fold expression): + +```cpp +template +__host__ __device__ constexpr auto compute_element_space_size( + const Tuple& lengths, + const Tuple& strides, + Sequence) +{ + return (LongNumber<1>{} + ... + + ((lengths[Number{}] - Number<1>{}) * strides[Number{}])); +} +``` + +This reduced `calculate_element_space_size` instantiations from 24 to 10 (58% reduction) and wall-clock time by 73%. diff --git a/include/ck/config.h.in b/include/ck/config.h.in index f5421e7d5e..306a6c2ff1 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -55,9 +55,6 @@ #ifndef CK_ENABLE_FP32 #define CK_ENABLE_FP32 "ON" #endif -#ifndef CK_ENABLE_TF32 -#define CK_ENABLE_TF32 "ON" -#endif #ifndef CK_ENABLE_FP64 #define CK_ENABLE_FP64 "ON" #endif @@ -88,10 +85,6 @@ #cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@ #endif -#ifndef CK_ENABLE_TF32 -#cmakedefine CK_ENABLE_TF32 @CK_ENABLE_TF32@ -#endif - #ifndef CK_ENABLE_FP64 #cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@ #endif diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp index 701c786c86..1c322fe4a7 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp @@ -160,6 +160,7 @@ struct ThreadGroupTransferGlobal // check if src element is valid const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + oob_thread_scratch_.template SetAsType(vgpr_data_idx_seq, is_src_valid); // Vector length of elementwise operation constexpr auto get_elem_op_vec_len = []() { @@ -195,14 +196,12 @@ struct ThreadGroupTransferGlobal using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; - using vector_t = typename vector_type_maker::type::type; - dst_vector_type op_r_v; // Load data from memory in src_vector first - src_vector_container src_vector = - src_vector_container{grid_buf.template Get( - src_coord_.GetOffset(), true)}; + auto index = is_src_valid || !DoTranspose ? src_coord_.GetOffset() : 0; + src_vector_container src_vector = src_vector_container{ + grid_buf.template Get(index, true)}; // apply the src elementwise op and convert to DstData under the hood if needed static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) { @@ -213,9 +212,8 @@ struct ThreadGroupTransferGlobal // store result in dvgpr_ (static array holding loaded data). // At this point data is already converted to DstData type and // the elementwise operation has been applied - dvgpr_.template SetAsType( - vgpr_data_idx_seq, - is_src_valid ? op_r_v.template AsType()[I0] : vector_t(0)); + src_dvgpr_.template SetAsType(vgpr_data_idx_seq, + op_r_v.template AsType()[I0]); // For each dimension move fwd, bwd or don't move static_for<0, nDim, 1>{}([&](auto i) { @@ -248,6 +246,39 @@ struct ThreadGroupTransferGlobal container_reorder_given_new2old(src_access_lengths, src_dim_access_order); constexpr auto ordered_fwd_step = StepsPerIteration{}; + // OOB check + static_ford{}([&](auto ordered_src_access_idx) { + // calculate src data index and make sequence + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}( + [&](auto i) { ordered_idx(i) = ordered_src_access_idx[i]; }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order); + }(); + + // make sequence to access vgpr data. Add zero as last element of src_data_idx_seq + constexpr auto vgpr_data_idx_seq = generate_sequence_v2( + [&](auto i) { + if constexpr(i.value < src_data_idx.Size()) + { + return Number{}; + } + else + { + return Number<0>{}; + } + }, + Number{}); + + auto op_r = src_dvgpr_.template GetAsType(vgpr_data_idx_seq); + const bool is_src_valid = + oob_thread_scratch_.template GetAsType(vgpr_data_idx_seq); + auto op_r_v = is_src_valid ? op_r : dst_vector_t(0); + dst_dvgpr_.template SetAsType(vgpr_data_idx_seq, op_r_v); + }); + // make forward steps // forward step for each iteration just add 1 const auto dst_forward_steps = generate_tuple( @@ -352,7 +383,7 @@ struct ThreadGroupTransferGlobal dst_buf.template Set( dst_coord_.GetOffset(), true, - dvgpr_.template GetAsType(vgpr_data_idx_seq)); + dst_dvgpr_.template GetAsType(vgpr_data_idx_seq)); // For each dimension move fwd, bwd or don't move static_for<0, nDim, 1>{}([&](auto i) { @@ -389,6 +420,14 @@ struct ThreadGroupTransferGlobal return make_naive_tensor_descriptor_packed(access_lengths_as_tuple); } + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto access_lengths_as_tuple = + container_push_back(sequence_to_tuple_of_number(NumberOfIterations{}), Number<1>{}); + + return make_naive_tensor_descriptor_packed(access_lengths_as_tuple); + } + static constexpr auto thread_data_scratch_desc_ = decltype(GetThreadScratchDataDescriptor()){}; using ThreadScratchData = StaticTensorTupleOfVectorBuffer; - ThreadScratchData dvgpr_; + static constexpr auto src_oob_thread_scratch_desc_ = + decltype(GetSrcThreadScratchDescriptor()){}; + using OOBThreadScratch = StaticTensorTupleOfVectorBuffer; + + ThreadScratchData src_dvgpr_; + ThreadScratchData dst_dvgpr_; + OOBThreadScratch oob_thread_scratch_; SrcCoord src_coord_; DstCoord dst_coord_; const ElementwiseOperation element_op_; diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp index 58da96e2f0..eadfa29c9f 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -11,8 +11,6 @@ namespace ck { namespace tensor_operation { namespace device { -#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1 - template #include -#include #include #include #include "ck/ck.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp" -#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" #include "ck/utility/tuple.hpp" namespace ck { namespace tensor_operation { namespace device { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) - - __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - - const long_index_t a_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx))); - const long_index_t b0_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); - const long_index_t b1_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); - const long_index_t c_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(g_idx))); - - GridwiseOp::template Run( - arg.p_a_grid + a_batch_offset, - arg.p_b0_grid + b0_batch_offset, - Tuple<>{}, // p_d0s_grid - arg.p_b1_grid + b1_batch_offset, - Tuple<>{}, // p_d1s_grid - arg.p_c_grid + c_batch_offset, - p_shared, - arg.a_grid_desc, - arg.b0_grid_desc, - Tuple<>{}, // D0sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - arg.b1_grid_desc, - Tuple<>{}, // D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - arg.c_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op, - arg.b0_element_op, - arg.acc_element_op, - arg.b1_element_op, - arg.c_element_op, - arg.block_2_ctile_map); -#else - ignore = arg; -#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__) -} - // Computes C = A * B0 * B1 // MN = MK * KL * LN // ^^^^^^ (Acc0) @@ -157,88 +101,47 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm, - Sequence, - GemmSpec, - TensorSpecialization::Default, // ASpec - TensorSpecialization::Default, // B0Spec - TensorSpecialization::Default, // B1Spec - TensorSpecialization::Default>; // CSpec - - __host__ __device__ static auto - MakeAGridDescriptor(const std::array& a_g_m_k_lengths_vec, - const std::array& a_g_m_k_strides_vec) - { - return Transform::MakeAGridDescriptor_AK0_M_AK1( - Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB0GridDescriptor(const std::array& b0_g_l_k_lengths_vec, - const std::array& b0_g_l_k_strides_vec) - { - return Transform::MakeB0GridDescriptor_BK0_N_BK1( - Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB1GridDescriptor(const std::array& b1_g_n_l_lengths_vec, - const std::array& b1_g_n_l_strides_vec) - { - return Transform::MakeB1GridDescriptor_BK0_N_BK1( - Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec), - Number{}); - } - - using AGridDesc = decltype(MakeAGridDescriptor({}, {})); - using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); - using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); - using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); - - struct ComputeBasePtrOfStridedBatch - { - ComputeBasePtrOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB0, - index_t BatchStrideB1, - index_t BatchStrideC) - : BatchStrideA_(BatchStrideA), - BatchStrideB0_(BatchStrideB0), - BatchStrideB1_(BatchStrideB1), - BatchStrideC_(BatchStrideC) - { - } - - __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB0_); - } - - __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB1_); - } - - __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - - private: - index_t BatchStrideA_; - index_t BatchStrideB0_; - index_t BatchStrideB1_; - index_t BatchStrideC_; - }; + using DeviceGemmGemmCommonBase = + DeviceGemmGemm_Wmma_CShuffleV3_Common, // D0sLayout + B1Layout, + Tuple<>, // D1sLayout + CLayout, + BlockSize, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock, + ADataType, + B0DataType, + B1DataType, + AccDataType, + CDataType, + Tuple<>, // D0sDataType + Tuple<>, // D1sDataType + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + AK1, + BK1, + L1, + MPerWmma, + LPerWmma, + BlkGemmPipelineVer, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector + CShuffleBlockTransferScalarPerVector_NPerBlock, + false>; // IsMultiD // GridwiseOp using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3< @@ -260,12 +163,12 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm, // Ds0GridDesc - B1GridDesc, + typename DeviceGemmGemmCommonBase::B1GridDesc, Tuple<>, // Ds1GridDesc - CGridDesc_M_N, + typename DeviceGemmGemmCommonBase::CGridDesc_M_N, // Tiling Family MPerBlock, LPerBlock, @@ -312,339 +215,67 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm; - struct RawArg : public BaseArgument + using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg< + DeviceOp, + GemmSpec, + ALayout, + B0layout, + Tuple<>, // D0sLayout + B1Layout, + Tuple<>, // D1sLayout + CLayout, + BlockSize, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock, + ADataType, + B0DataType, + B1DataType, + AccDataType, + CDataType, + Tuple<>, // D0sDataType, + Tuple<>, // D1sDataType, + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + AK1, + BK1, + L1, + MPerWmma, + LPerWmma, + BlkGemmPipelineVer, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector + CShuffleBlockTransferScalarPerVector_NPerBlock, + false>; // IsMultiD + // Invoker + using Invoker = typename DeviceGemmGemmCommon::Invoker; + + // Argument + using Argument = typename DeviceGemmGemmCommon::Argument; + + static bool IsSupportedArgument(const Argument& arg) { - using arr3 = std::array; - - RawArg(const ADataType* p_a_grid_, - const B0DataType* p_b0_grid_, - const B1DataType* p_b1_grid_, - CDataType* p_c_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t O_, - index_t Batch, - index_t StrideA, - index_t StrideB0, - index_t StrideB1, - index_t StrideC, - index_t BatchStrideA, - index_t BatchStrideB0, - index_t BatchStrideB1, - index_t BatchStrideC, - AElementwiseOperation a_element_op_, - B0ElementwiseOperation b0_element_op_, - AccElementwiseOperation acc_element_op_, - B1ElementwiseOperation b1_element_op_, - CElementwiseOperation c_element_op_) - : p_a_grid{p_a_grid_}, - p_b0_grid{p_b0_grid_}, - p_b1_grid{p_b1_grid_}, - p_c_grid{p_c_grid_}, - M{M_}, - N{N_}, - K{K_}, - O{O_}, - batch_count{Batch}, - a_element_op{a_element_op_}, - b0_element_op{b0_element_op_}, - acc_element_op{acc_element_op_}, - b1_element_op{b1_element_op_}, - c_element_op{c_element_op_}, - compute_base_ptr_of_batch{BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC} - { - - a_g_m_k_lengths = arr3{batch_count, M, K}; - a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K] - - b0_g_n_k_lengths = arr3{batch_count, N, K}; - b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K] - - b1_g_o_n_lengths = arr3{batch_count, O, N}; - b1_g_o_n_strides = - is_same_v - ? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O] - : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N] - - c_g_m_o_lengths = arr3{batch_count, M, O}; - c_g_m_o_strides = arr3{BatchStrideC, StrideC, 1}; // C layout [batch_count, M, O] - - a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides); - b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides); - b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides); - c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides); - c_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); - block_2_ctile_map = GridwiseOp::MakeDefaultBlock2ETileMap(c_grid_desc_m_n, 1, 1); - } - // Pointers - const ADataType* p_a_grid; - const B0DataType* p_b0_grid; - const B1DataType* p_b1_grid; - CDataType* p_c_grid; - - // Raw Problem Size - index_t M; - index_t N; - index_t K; - index_t O; - index_t batch_count; - - arr3 a_g_m_k_lengths; - arr3 a_g_m_k_strides; - arr3 b0_g_n_k_lengths; - arr3 b0_g_n_k_strides; - arr3 b1_g_o_n_lengths; - arr3 b1_g_o_n_strides; - arr3 c_g_m_o_lengths; - arr3 c_g_m_o_strides; - - AElementwiseOperation a_element_op; - B0ElementwiseOperation b0_element_op; - AccElementwiseOperation acc_element_op; - B1ElementwiseOperation b1_element_op; - CElementwiseOperation c_element_op; - - // Grid descriptors and other mem calculators - AGridDesc a_grid_desc; - B0GridDesc b0_grid_desc; - B1GridDesc b1_grid_desc; - CGridDesc_M_N c_grid_desc_m_n; - typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock; - - typename GridwiseOp::DefaultBlock2ETileMap block_2_ctile_map; - - ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch; - }; - - static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg) - { - // Print lambda with env check and printf() style formmating. - const char* curFunc = __func__; - auto print = [&curFunc](const char* format, ...) -> void { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { -#if defined(__clang__) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wformat-nonliteral" -#endif - va_list args; - va_start(args, format); - std::vfprintf(stdout, format, args); - va_end(args); -#if defined(__clang__) -#pragma clang diagnostic pop -#endif - std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n"; - } - }; - - if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) - { - print("DeviceOp: Arch err\n"); - return false; - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - print("DeviceOp: gfx 11 does not support fp8\n"); - return false; - } - } - - if constexpr(!(is_same_v || is_same_v)) - { - print("DeviceOp: Acc0 Type err\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: A layout must be Row\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: B layout must be Column\n"); - return false; - } - - if constexpr(!(is_same_v || - is_same_v)) - { - print("DeviceOp: B1 layout must be Column or Row\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: C layout must be Row\n"); - return false; - } - - // Other padding modes have not been tested and do not get checked individually. - if constexpr(GemmSpec != GemmSpecialization::Default && - GemmSpec != GemmSpecialization::MNKOPadding) - { - print("Padding mode must be default or MNKO\n"); - return false; - } - - // Per wmma dimensions not equal to 16 are very untested. - if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16) - { - print("M, L, N per Wmma must be 16\n"); - return false; - } - - if(!GridwiseOp::CheckValidity(arg.a_grid_desc, - arg.b0_grid_desc, - Tuple<>{}, - arg.b1_grid_desc, - Tuple<>{}, - arg.c_grid_desc_m_n, - arg.block_2_ctile_map)) - { - return false; - } - - // Check scalar per vector requirement - const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; - const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; - const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; - const auto c_extent_lowest = arg.O; - - if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && - b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && - b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && - c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) - { - print("DeviceOp: Data Transfer Vector scalar err\n"); - return false; - } - - // Check vector load/store requirement - const auto a_stride_lowest = - ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; - const auto b0_stride_lowest = - B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1]; - const auto b1_stride_lowest = - B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1]; - const auto c_stride_lowest = arg.c_g_m_o_strides[2]; - - if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || - c_stride_lowest == 1)) - { - print("DeviceOp: Data Vectorize transfer err\n"); - return false; - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding)) - { - return false; - } - - return true; + return DeviceGemmGemmCommon::IsSupportedArgument(arg); } - // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { - return IsSupportedArgument(*dynamic_cast(p_arg)); + return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast(p_arg)); } - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::RawArg; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock); - const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock); - - const index_t grid_size = arg.batch_count * M0 * N0; - - auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) { - constexpr bool has_loop = decltype(has_main_k_block_loop)::value; - constexpr TailNumber tn = tail_number; - - const auto kernel = - kernel_batched_gemm_gemm_wmma_cshuffle_v3; - - return launch_and_time_kernel( - stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg); - }; - - bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K); - TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K); - - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else - { - printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n"); - return 0.0f; - } - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Even) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else - { - printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n"); - return 0.0f; - } - } - else - { - printf("Invalid pipeline version!\n"); - return 0.0f; - } - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - // polymorphic std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b0, @@ -669,28 +300,39 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm(static_cast(p_a), - static_cast(p_b0), - static_cast(p_b1), - static_cast(p_c), - M, - N, - K, - O, - Batch, - StrideA, - StrideB0, - StrideB1, - StrideC, - BatchStrideA, - BatchStrideB0, - BatchStrideB1, - BatchStrideC, - a_element_op, - b0_element_op, - acc_element_op, - b1_element_op, - c_element_op); + + std::array p_d0_grid{}; + std::array p_d1_grid{}; + std::array StrideD0s{}, BatchStrideD0s{}; + std::array StrideD1s, BatchStrideD1s{}; + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + p_d0_grid, + static_cast(p_b1), + p_d1_grid, + static_cast(p_c), + M, + N, + K, + O, + Batch, + StrideA, + StrideB0, + StrideD0s, + StrideB1, + StrideD1s, + StrideC, + BatchStrideA, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideC, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); } static auto MakeInvoker() { return Invoker{}; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp new file mode 100644 index 0000000000..a739af898f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp @@ -0,0 +1,902 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/integral_constant.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::Argument arg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_e1_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCE1BasePtr(g_idx))); + + auto [p_d0s_grid, p_d1s_grid] = [&]() { + if constexpr(IsMultiD) + { + auto create_grid = [](auto NumTensor, auto func, auto& arg_grid, auto&& grid_pointer) { + static_for<0, decltype(NumTensor)::value, 1>{}([&](auto In) { + const long_index_t batch_offset = __builtin_amdgcn_readfirstlane(func(In)); + grid_pointer(In) = arg_grid(In) + batch_offset; + }); + return std::move(grid_pointer); + }; + auto get_d0_base_ptr = [&arg, &g_idx](auto d_idx) { + return arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, d_idx); + }; + auto get_d1_base_ptr = [&arg, &g_idx](auto d_idx) { + return arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, d_idx); + }; + auto d0s_grid = create_grid(ck::integral_constant{}, + get_d0_base_ptr, + arg.p_d0s_grid, + GridwiseOp::MakeD0sGridPointer()); + auto d1s_grid = create_grid(ck::integral_constant{}, + get_d1_base_ptr, + arg.p_d1s_grid, + GridwiseOp::MakeD1sGridPointer()); + return std::make_pair(d0s_grid, d1s_grid); + } + else + { + return std::make_pair(Tuple<>{}, Tuple<>{}); + } + }(); + + GridwiseOp::template Run( + arg.p_a_grid + a_batch_offset, + arg.p_b0_grid + b0_batch_offset, + p_d0s_grid, + arg.p_b1_grid + b1_batch_offset, + p_d1s_grid, + arg.p_c_e1_grid + c_e1_batch_offset, + p_shared, + arg.a_grid_desc, + arg.b0_grid_desc, + arg.d0s_grid_desc, + arg.b1_grid_desc, + arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock, + arg.c_e1_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op, + arg.b0_element_op, + arg.acc_element_op, + arg.b1_element_op, + arg.cde1_element_op, + arg.block_2_etile_map); +#else + ignore = arg; +#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__) +} + +template +struct DeviceGemmGemm_Wmma_CShuffleV3_Common +{ + static constexpr ck::index_t NumD0Tensor = []() { + if constexpr(IsMultiD) + { + return DeviceOp::NumD0Tensor; + } + return 0; + }(); + static constexpr ck::index_t NumD1Tensor = []() { + if constexpr(IsMultiD) + { + return DeviceOp::NumD1Tensor; + } + return 0; + }(); + + struct GridDescriptorCreator + { + // TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler + // Transform operator or just not use one at all. + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence<1, 1, 1, 1, 1>, + Sequence, + GemmSpec, + TensorSpecialization::Default, // ASpec + TensorSpecialization::Default, // B0Spec + TensorSpecialization::Default, // B1Spec + TensorSpecialization::Default>; // CSpec + + __host__ __device__ static auto + MakeAGridDescriptor(const std::array& a_g_m_k_lengths_vec, + const std::array& a_g_m_k_strides_vec) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeB0GridDescriptor(const std::array& b0_g_l_k_lengths_vec, + const std::array& b0_g_l_k_strides_vec) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeB1GridDescriptor(const std::array& b1_g_n_l_lengths_vec, + const std::array& b1_g_n_l_strides_vec) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeD0GridDescriptor(const std::array& d0_g_m_n_lengths_vec, + const std::array& d0_g_m_n_strides_vec) + { + return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec); + } + + __host__ __device__ static auto MakeD0sGridDescriptor( + const std::array, NumD0Tensor>& d0_g_m_n_lengths_vec, + const std::array, NumD0Tensor>& d0_g_m_n_strides_vec) + { + return generate_tuple( + [&](auto i) { + return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]); + }, + Number{}); + } + + __host__ __device__ static auto MakeD1sGridDescriptor( + const std::array, NumD1Tensor>& d1_g_m_o_lengths_vec, + const std::array, NumD1Tensor>& d1_g_m_o_strides_vec) + { + return generate_tuple( + [&](auto i) { + return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]); + }, + Number{}); + } + + __host__ __device__ static auto + MakeE1GridDescriptor(const std::array& e1_g_m_n_lengths_vec, + const std::array& e1_g_m_n_strides_vec) + { + return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec); + } + }; + + using AGridDesc = decltype(GridDescriptorCreator::MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(GridDescriptorCreator::MakeB0GridDescriptor({}, {})); + using D0sGridDesc = + remove_cvref_t; + using B1GridDesc = decltype(GridDescriptorCreator::MakeB1GridDescriptor({}, {})); + using D1sGridDesc = + remove_cvref_t; + using E1GridDesc = decltype(GridDescriptorCreator::MakeE1GridDescriptor({}, {})); + using CGridDesc_M_N = + decltype(GridDescriptorCreator::Transform::MakeCGridDescriptor_M_N({}, {})); + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB0, + index_t BatchStrideB1, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), + BatchStrideB0_(BatchStrideB0), + BatchStrideB1_(BatchStrideB1), + BatchStrideC_E1_(BatchStrideC) + { + } + + ComputeBasePtrOfStridedBatch(index_t BatchStrideA0, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1) + : BatchStrideA_(BatchStrideA0), + BatchStrideB0_(BatchStrideB0), + BatchStrideD0s_(BatchStrideD0s), + BatchStrideB1_(BatchStrideB1), + BatchStrideD1s_(BatchStrideD1s), + BatchStrideC_E1_(BatchStrideE1) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB0_); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB1_); + } + + __host__ __device__ constexpr long_index_t GetCE1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_E1_); + } + + template + __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, + Number d0_idx) const + { + return g_idx * static_cast(BatchStrideD0s_[d0_idx]); + } + + template + __host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx, + Number d1_idx) const + { + return g_idx * static_cast(BatchStrideD1s_[d1_idx]); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB0_; + std::array BatchStrideD0s_; + index_t BatchStrideB1_; + std::array BatchStrideD1s_; + index_t BatchStrideC_E1_; + }; +}; + +template +struct DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg +{ + using GridwiseGemm = typename DeviceOp::GridwiseOp; + using Common = + DeviceGemmGemm_Wmma_CShuffleV3_Common; + + static constexpr auto NumD0Tensor = Common::NumD0Tensor; + static constexpr auto NumD1Tensor = Common::NumD1Tensor; + + struct Argument : public BaseArgument + { + using arr3 = std::array; + + Argument(const ADataType* p_a_grid_, + const B0DataType* p_b0_grid_, + std::array p_d0s_grid_, + const B1DataType* p_b1_grid_, + std::array p_d1s_grid_, + CE1DataType* p_e1_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t O_, + index_t Batch, + index_t StrideA, + index_t StrideB0, + std::array StrideD0s, + index_t StrideB1, + std::array StrideD1s, + index_t StrideE1, + index_t BatchStrideA, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1, + AElementwiseOperation a_element_op_, + B0ElementwiseOperation b0_element_op_, + AccElementwiseOperation acc_element_op_, + B1ElementwiseOperation b1_element_op_, + CDE1ElementwiseOperation cde1_element_op_) + : p_a_grid{p_a_grid_}, + p_b0_grid{p_b0_grid_}, + p_d0s_grid{}, + p_b1_grid{p_b1_grid_}, + p_d1s_grid{}, + p_c_e1_grid{p_e1_grid_}, + M{M_}, + N{N_}, + K{K_}, + O{O_}, + batch_count{Batch}, + a_element_op{a_element_op_}, + b0_element_op{b0_element_op_}, + acc_element_op{acc_element_op_}, + b1_element_op{b1_element_op_}, + cde1_element_op{cde1_element_op_}, + compute_base_ptr_of_batch{BatchStrideA, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideE1} + { + + a_g_m_k_lengths = arr3{batch_count, M, K}; + a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K] + + b0_g_n_k_lengths = arr3{batch_count, N, K}; + b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K] + + b1_g_o_n_lengths = arr3{batch_count, O, N}; + b1_g_o_n_strides = + is_same_v + ? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O] + : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N] + + e1_g_m_o_lengths = arr3{batch_count, M, O}; + e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O] + + a_grid_desc = Common::GridDescriptorCreator::MakeAGridDescriptor(a_g_m_k_lengths, + a_g_m_k_strides); + b0_grid_desc = Common::GridDescriptorCreator::MakeB0GridDescriptor(b0_g_n_k_lengths, + b0_g_n_k_strides); + b1_grid_desc = Common::GridDescriptorCreator::MakeB1GridDescriptor(b1_g_o_n_lengths, + b1_g_o_n_strides); + c_e1_grid_desc_m_n = Common::GridDescriptorCreator::MakeE1GridDescriptor( + e1_g_m_o_lengths, e1_g_m_o_strides); + c_e1_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_e1_grid_desc_m_n); + block_2_etile_map = GridwiseGemm::MakeDefaultBlock2ETileMap(c_e1_grid_desc_m_n, 1, 1); + + if constexpr(IsMultiD) + { + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + using D0DataType = remove_cvref_t>; + + // D0s layout [batch_count, M, N] + d0s_g_m_n_lengths[i] = arr3{batch_count, M, N}; + d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1}; + + // D0 pointer + p_d0s_grid(i) = static_cast(p_d0s_grid_[i]); + }); + // D0 desc + d0s_grid_desc = Common::GridDescriptorCreator::MakeD0sGridDescriptor( + d0s_g_m_n_lengths, d0s_g_m_n_strides); + + static_for<0, NumD1Tensor, 1>{}([&](auto i) { + using D1DataType = remove_cvref_t>; + + // D1s layout [batch_count, M, O] + d1s_g_m_o_lengths[i] = arr3{batch_count, M, O}; + d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1}; + + // D1 pointer + p_d1s_grid(i) = static_cast(p_d1s_grid_[i]); + }); + // D1 desc + d1s_grid_desc = Common::GridDescriptorCreator::MakeD1sGridDescriptor( + d1s_g_m_o_lengths, d1s_g_m_o_strides); + + d1s_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d1s_grid_desc); + } + } + + // Pointers + const ADataType* p_a_grid; + const B0DataType* p_b0_grid; + typename GridwiseGemm::D0sGridPointer p_d0s_grid; + const B1DataType* p_b1_grid; + typename GridwiseGemm::D1sGridPointer p_d1s_grid; + CE1DataType* p_c_e1_grid; + + // Raw Problem Size + index_t M; + index_t N; + index_t K; + index_t O; + index_t batch_count; + + arr3 a_g_m_k_lengths; + arr3 a_g_m_k_strides; + arr3 b0_g_n_k_lengths; + arr3 b0_g_n_k_strides; + std::array d0s_g_m_n_lengths; + std::array d0s_g_m_n_strides; + arr3 b1_g_o_n_lengths; + arr3 b1_g_o_n_strides; + std::array d1s_g_m_o_lengths; + std::array d1s_g_m_o_strides; + arr3 e1_g_m_o_lengths; + arr3 e1_g_m_o_strides; + + AElementwiseOperation a_element_op; + B0ElementwiseOperation b0_element_op; + AccElementwiseOperation acc_element_op; + B1ElementwiseOperation b1_element_op; + CDE1ElementwiseOperation cde1_element_op; + + // Grid descriptors and other mem calculators + typename Common::AGridDesc a_grid_desc; + typename Common::B0GridDesc b0_grid_desc; + std::conditional_t> d0s_grid_desc; + typename Common::B1GridDesc b1_grid_desc; + typename Common::D1sGridDesc d1s_grid_desc; + std::conditional_t< + IsMultiD, + typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + Tuple<>> + d1s_grid_desc_mblock_mperblock_nblock_nperblock; + + std::conditional_t + c_e1_grid_desc_m_n; + typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_e1_grid_desc_mblock_mperblock_nblock_nperblock; + + typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map; + + typename Common::ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch; + }; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock); + + const index_t grid_size = arg.batch_count * M0 * N0; + + auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) { + constexpr bool has_loop = decltype(has_main_k_block_loop)::value; + constexpr TailNumber tail_num = decltype(tail_number)::value; + const auto kernel = kernel_batched_gemm_gemm_wmma_cshuffle_v3; + return launch_and_time_kernel( + stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg); + }; + + bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(arg.K); + TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.K); + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n"); + return 0.0f; + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Even) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n"); + return 0.0f; + } + } + else + { + printf("Invalid pipeline version!\n"); + return 0.0f; + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + // check if DsLayout is supported + template + static constexpr bool CheckDLayout() + { + bool valid = true; + // iterate over DLayout tuple + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + // if RefLayout and DLayout are same, keep valid true, otherwise false + valid = valid && is_same_v; + }); + return valid; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // Print lambda with env check and printf() style formmating. + const char* curFunc = __func__; + auto print = [&curFunc](const char* format, ...) -> void { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wformat-nonliteral" +#endif + va_list args; + va_start(args, format); + std::vfprintf(stdout, format, args); + va_end(args); +#if defined(__clang__) +#pragma clang diagnostic pop +#endif + std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n"; + } + }; + + if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + print("DeviceOp: Arch err\n"); + return false; + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + print("DeviceOp: gfx 11 does not support fp8\n"); + return false; + } + } + + if constexpr(!(is_same_v || is_same_v)) + { + print("DeviceOp: Acc0 Type err\n"); + return false; + } + + if constexpr(!(is_same_v)) + { + print("DeviceOp: A layout must be Row\n"); + return false; + } + + if constexpr(!(is_same_v || + is_same_v)) + { + print("DeviceOp: B1 layout must be Column or Row\n"); + return false; + } + + if constexpr(!(is_same_v)) + { + print("DeviceOp: C layout must be Row\n"); + return false; + } + + // Other padding modes have not been tested and do not get checked individually. + if constexpr(GemmSpec != GemmSpecialization::Default && + GemmSpec != GemmSpecialization::MNKOPadding) + { + print("Padding mode must be default or MNKO\n"); + return false; + } + + // Per wmma dimensions not equal to 16 are very untested. + if constexpr(MPerWmma != 16 || LPerWmma != 16 || DeviceOp::NPerWmma != 16) + { + print("M, L, N per Wmma must be 16\n"); + return false; + } + + if constexpr(IsMultiD) + { + if constexpr(!(is_same_v)) + { + print("DeviceOp: B0 layout must be Column\n"); + return false; + } + + if constexpr(!(CheckDLayout())) + { + print("DeviceOp: All D0s layout must be Row\n"); + return false; + } + + if constexpr(!(CheckDLayout())) + { + print("DeviceOp: All D1s layout must be Row\n"); + return false; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.d0s_grid_desc, + arg.b1_grid_desc, + arg.d1s_grid_desc, + arg.c_e1_grid_desc_m_n, + arg.block_2_etile_map)) + { + return false; + } + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; + const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; + const auto cde1_extent_lowest = arg.O; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + print("DeviceOp: Data Transfer Vector scalar err\n"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; + const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2 + ? arg.b0_g_n_k_strides[2] + : arg.b0_g_n_k_strides[1]; + const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2 + ? arg.b1_g_o_n_strides[2] + : arg.b1_g_o_n_strides[1]; + const auto e1_stride_lowest = arg.e1_g_m_o_strides[2]; + + // NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major + // and the lowest dimension stride is hardcoded to 1 + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + e1_stride_lowest == 1)) + { + print("DeviceOp: Data Vectorize transfer err\n"); + return false; + } + } + else + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + Tuple<>{}, + arg.b1_grid_desc, + Tuple<>{}, + arg.c_e1_grid_desc_m_n, + arg.block_2_etile_map)) + { + return false; + } + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; + const auto c_extent_lowest = arg.O; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + print("DeviceOp: Data Transfer Vector scalar err\n"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; + const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2 + ? arg.b0_g_n_k_strides[2] + : arg.b0_g_n_k_strides[1]; + const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2 + ? arg.b1_g_o_n_strides[2] + : arg.b1_g_o_n_strides[1]; + const auto c_stride_lowest = arg.e1_g_m_o_strides[2]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + print("DeviceOp: Data Vectorize transfer err\n"); + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding)) + { + return false; + } + + return true; + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp index 06651c0c0e..83fec9c95f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -3,91 +3,20 @@ #pragma once -#include #include -#include #include #include #include "ck/ck.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp" -#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { namespace device { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3(typename DeviceOp::RawArg arg) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) - - __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - - const long_index_t a_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx))); - const long_index_t b0_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); - const long_index_t b1_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); - const long_index_t e1_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetE1BasePtr(g_idx))); - - auto p_d0s_grid = GridwiseOp::MakeD0sGridPointer(); - auto p_d1s_grid = GridwiseOp::MakeD1sGridPointer(); - - static_for<0, DeviceOp::NumD0Tensor, 1>{}([&](auto In) { - const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); - p_d0s_grid(In) = arg.p_d0s_grid(In) + d0_batch_offset; - }); - - static_for<0, DeviceOp::NumD1Tensor, 1>{}([&](auto In) { - const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In))); - p_d1s_grid(In) = arg.p_d1s_grid(In) + d1_batch_offset; - }); - - GridwiseOp::template Run( - arg.p_a_grid + a_batch_offset, - arg.p_b0_grid + b0_batch_offset, - p_d0s_grid, - arg.p_b1_grid + b1_batch_offset, - p_d1s_grid, - arg.p_e1_grid + e1_batch_offset, - p_shared, - arg.a_grid_desc, - arg.b0_grid_desc, - arg.d0s_grid_desc, - arg.b1_grid_desc, - arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock, - arg.e1_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op, - arg.b0_element_op, - arg.acc_element_op, - arg.b1_element_op, - arg.cde1_element_op, - arg.block_2_etile_map); -#else - ignore = arg; -#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__) -} - // Computes: // Acc = Acc_Op(A_Op(A) * B0_Op(B0), D0_0, D0_1, ...) // E = CDE1_Op(Acc_Op(Acc0) * B1_Op(B1), D1_0, D1_1, ...) @@ -184,151 +113,51 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 static constexpr index_t NumD0Tensor = D0sDataType::Size(); static constexpr index_t NumD1Tensor = D1sDataType::Size(); - static constexpr auto I0 = Number<0>{}; - // To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal // to LPerWmma (A.k.a Gemm0 NPerWmma). static constexpr index_t NPerWmma = LPerWmma; - // TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler - // Transform operator or just not use one at all. - using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< - Sequence<1, 1, 1, 1, 1>, - Sequence, - GemmSpec, - TensorSpecialization::Default, // ASpec - TensorSpecialization::Default, // B0Spec - TensorSpecialization::Default, // B1Spec - TensorSpecialization::Default>; // CSpec - - __host__ __device__ static auto - MakeAGridDescriptor(const std::array& a_g_m_k_lengths_vec, - const std::array& a_g_m_k_strides_vec) - { - return Transform::MakeAGridDescriptor_AK0_M_AK1( - Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB0GridDescriptor(const std::array& b0_g_l_k_lengths_vec, - const std::array& b0_g_l_k_strides_vec) - { - return Transform::MakeB0GridDescriptor_BK0_N_BK1( - Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB1GridDescriptor(const std::array& b1_g_n_l_lengths_vec, - const std::array& b1_g_n_l_strides_vec) - { - return Transform::MakeB1GridDescriptor_BK0_N_BK1( - Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeD0GridDescriptor(const std::array& d0_g_m_n_lengths_vec, - const std::array& d0_g_m_n_strides_vec) - { - return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec); - } - - __host__ __device__ static auto MakeD0sGridDescriptor( - const std::array, NumD0Tensor>& d0_g_m_n_lengths_vec, - const std::array, NumD0Tensor>& d0_g_m_n_strides_vec) - { - return generate_tuple( - [&](auto i) { - return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]); - }, - Number{}); - } - - __host__ __device__ static auto MakeD1sGridDescriptor( - const std::array, NumD0Tensor>& d1_g_m_o_lengths_vec, - const std::array, NumD0Tensor>& d1_g_m_o_strides_vec) - { - return generate_tuple( - [&](auto i) { - return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]); - }, - Number{}); - } - - __host__ __device__ static auto - MakeE1GridDescriptor(const std::array& e1_g_m_n_lengths_vec, - const std::array& e1_g_m_n_strides_vec) - { - return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec); - } - - using AGridDesc = decltype(MakeAGridDescriptor({}, {})); - using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); - using D0sGridDesc = remove_cvref_t; - using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); - using D1sGridDesc = remove_cvref_t; - using E1GridDesc = decltype(MakeE1GridDescriptor({}, {})); - - struct ComputeBasePtrOfStridedBatch - { - ComputeBasePtrOfStridedBatch(index_t BatchStrideA0, - index_t BatchStrideB0, - std::array BatchStrideD0s, - index_t BatchStrideB1, - std::array BatchStrideD1s, - index_t BatchStrideE1) - : BatchStrideA0_(BatchStrideA0), - BatchStrideB0_(BatchStrideB0), - BatchStrideD0s_(BatchStrideD0s), - BatchStrideB1_(BatchStrideB1), - BatchStrideD1s_(BatchStrideD1s), - BatchStrideE1_(BatchStrideE1) - { - } - - __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA0_); - } - - __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB0_); - } - - template - __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, - Number d1_idx) const - { - return g_idx * static_cast(BatchStrideD0s_[d1_idx]); - } - - __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB1_); - } - - __host__ __device__ constexpr long_index_t GetE1BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE1_); - } - - template - __host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number d1_idx) const - { - return g_idx * static_cast(BatchStrideD1s_[d1_idx]); - } - - private: - index_t BatchStrideA0_; - index_t BatchStrideB0_; - std::array BatchStrideD0s_; - index_t BatchStrideB1_; - std::array BatchStrideD1s_; - index_t BatchStrideE1_; - }; + using DeviceGemmGemmCommonBase = + DeviceGemmGemm_Wmma_CShuffleV3_Common; // IsMultiD // GridwiseOp using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3< @@ -350,12 +179,12 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, // InMemory Data Descriptor - AGridDesc, - B0GridDesc, - D0sGridDesc, - B1GridDesc, - D1sGridDesc, - E1GridDesc, + typename DeviceGemmGemmCommonBase::AGridDesc, + typename DeviceGemmGemmCommonBase::B0GridDesc, + typename DeviceGemmGemmCommonBase::D0sGridDesc, + typename DeviceGemmGemmCommonBase::B1GridDesc, + typename DeviceGemmGemmCommonBase::D1sGridDesc, + typename DeviceGemmGemmCommonBase::E1GridDesc, // Tiling Family MPerBlock, LPerBlock, @@ -402,430 +231,67 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, - Transform::matrix_padder.PadN, + DeviceGemmGemmCommonBase::GridDescriptorCreator::Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer>; - struct RawArg : public BaseArgument + using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg< + DeviceOp, + GemmSpec, + ALayout, + B0layout, + D0sLayout, + B1Layout, + D1sLayout, + E1Layout, + BlockSize, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock, + ADataType, + B0DataType, + B1DataType, + AccDataType, + E1DataType, + D0sDataType, + D1sDataType, + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CDE1ElementwiseOperation, + AK1, + BK1, + L1, + MPerWmma, + LPerWmma, + BlkGemmPipelineVer, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + CDE0BlockTransferSrcScalarPerVector, + CShuffleBlockTransferScalarPerVector_NPerBlock, + true>; // IsMultiD + // Invoker + using Invoker = typename DeviceGemmGemmCommon::Invoker; + + // Argument + using Argument = typename DeviceGemmGemmCommon::Argument; + + static bool IsSupportedArgument(const Argument& arg) { - using arr3 = std::array; - - RawArg(const ADataType* p_a_grid_, - const B0DataType* p_b0_grid_, - std::array p_d0s_grid_, - const B1DataType* p_b1_grid_, - std::array p_d1s_grid_, - E1DataType* p_e1_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t O_, - index_t Batch, - index_t StrideA, - index_t StrideB0, - std::array StrideD0s, - index_t StrideB1, - std::array StrideD1s, - index_t StrideE1, - index_t BatchStrideA, - index_t BatchStrideB0, - std::array BatchStrideD0s, - index_t BatchStrideB1, - std::array BatchStrideD1s, - index_t BatchStrideE1, - AElementwiseOperation a_element_op_, - B0ElementwiseOperation b0_element_op_, - AccElementwiseOperation acc_element_op_, - B1ElementwiseOperation b1_element_op_, - CDE1ElementwiseOperation cde1_element_op_) - : p_a_grid{p_a_grid_}, - p_b0_grid{p_b0_grid_}, - p_d0s_grid{}, - p_b1_grid{p_b1_grid_}, - p_d1s_grid{}, - p_e1_grid{p_e1_grid_}, - M{M_}, - N{N_}, - K{K_}, - O{O_}, - batch_count{Batch}, - a_element_op{a_element_op_}, - b0_element_op{b0_element_op_}, - acc_element_op{acc_element_op_}, - b1_element_op{b1_element_op_}, - cde1_element_op{cde1_element_op_}, - compute_base_ptr_of_batch{BatchStrideA, - BatchStrideB0, - BatchStrideD0s, - BatchStrideB1, - BatchStrideD1s, - BatchStrideE1} - { - - a_g_m_k_lengths = arr3{batch_count, M, K}; - a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K] - - b0_g_n_k_lengths = arr3{batch_count, N, K}; - b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K] - - b1_g_o_n_lengths = arr3{batch_count, O, N}; - b1_g_o_n_strides = - is_same_v - ? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O] - : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N] - - e1_g_m_o_lengths = arr3{batch_count, M, O}; - e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O] - - a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides); - b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides); - b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides); - e1_grid_desc_m_n = MakeE1GridDescriptor(e1_g_m_o_lengths, e1_g_m_o_strides); - e1_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e1_grid_desc_m_n); - block_2_etile_map = GridwiseOp::MakeDefaultBlock2ETileMap(e1_grid_desc_m_n, 1, 1); - - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - using D0DataType = remove_cvref_t>; - - // D0s layout [batch_count, M, N] - d0s_g_m_n_lengths[i] = arr3{batch_count, M, N}; - d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1}; - - // D0 pointer - p_d0s_grid(i) = static_cast(p_d0s_grid_[i]); - - // D0 desc - d0s_grid_desc(i) = MakeD0GridDescriptor(d0s_g_m_n_lengths[i], d0s_g_m_n_strides[i]); - }); - - static_for<0, NumD1Tensor, 1>{}([&](auto i) { - using D1DataType = remove_cvref_t>; - - // D1s layout [batch_count, M, O] - d1s_g_m_o_lengths[i] = arr3{batch_count, M, O}; - d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1}; - - // D1 pointer - p_d1s_grid(i) = static_cast(p_d1s_grid_[i]); - - // D1 desc - d1s_grid_desc(i) = MakeE1GridDescriptor(d1s_g_m_o_lengths[i], d1s_g_m_o_strides[i]); - }); - - d1s_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseOp::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(d1s_grid_desc); - } - - // Pointers - const ADataType* p_a_grid; - const B0DataType* p_b0_grid; - typename GridwiseOp::D0sGridPointer p_d0s_grid; - const B1DataType* p_b1_grid; - typename GridwiseOp::D1sGridPointer p_d1s_grid; - E1DataType* p_e1_grid; - - // Raw Problem Size - index_t M; - index_t N; - index_t K; - index_t O; - index_t batch_count; - - arr3 a_g_m_k_lengths; - arr3 a_g_m_k_strides; - arr3 b0_g_n_k_lengths; - arr3 b0_g_n_k_strides; - std::array d0s_g_m_n_lengths; - std::array d0s_g_m_n_strides; - arr3 b1_g_o_n_lengths; - arr3 b1_g_o_n_strides; - std::array d1s_g_m_o_lengths; - std::array d1s_g_m_o_strides; - arr3 e1_g_m_o_lengths; - arr3 e1_g_m_o_strides; - - AElementwiseOperation a_element_op; - B0ElementwiseOperation b0_element_op; - AccElementwiseOperation acc_element_op; - B1ElementwiseOperation b1_element_op; - CDE1ElementwiseOperation cde1_element_op; - - // Grid descriptors and other mem calculators - AGridDesc a_grid_desc; - B0GridDesc b0_grid_desc; - D0sGridDesc d0s_grid_desc; - B1GridDesc b1_grid_desc; - D1sGridDesc d1s_grid_desc; - typename GridwiseOp::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - d1s_grid_desc_mblock_mperblock_nblock_nperblock; - - E1GridDesc e1_grid_desc_m_n; - typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e1_grid_desc_mblock_mperblock_nblock_nperblock; - - typename GridwiseOp::DefaultBlock2ETileMap block_2_etile_map; - - ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch; - }; - - // check if DsLayout is supported - template - static constexpr bool CheckDLayout() - { - bool valid = true; - // iterate over DLayout tuple - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DLayout = remove_cvref_t>; - // if RefLayout and DLayout are same, keep valid true, otherwise false - valid = valid && is_same_v; - }); - return valid; + return DeviceGemmGemmCommon::IsSupportedArgument(arg); } - - static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg) - { - // Print lambda with env check and printf() style formmating. - const char* curFunc = __func__; - auto print = [&curFunc](const char* format, ...) -> void { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { -#if defined(__clang__) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wformat-nonliteral" -#endif - va_list args; - va_start(args, format); - std::vfprintf(stdout, format, args); - va_end(args); -#if defined(__clang__) -#pragma clang diagnostic pop -#endif - std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n"; - } - }; - - if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) - { - print("DeviceOp: Arch err\n"); - return false; - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - print("DeviceOp: gfx 11 does not support fp8\n"); - return false; - } - } - - if constexpr(!(is_same_v || is_same_v)) - { - print("DeviceOp: Acc0 Type err\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: A layout must be Row\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: B0 layout must be Column\n"); - return false; - } - - if constexpr(!(CheckDLayout())) - { - print("DeviceOp: All D0s layout must be Row\n"); - return false; - } - - if constexpr(!(is_same_v || - is_same_v)) - { - print("DeviceOp: B1 layout must be Column or Row\n"); - return false; - } - - if constexpr(!(CheckDLayout())) - { - print("DeviceOp: All D1s layout must be Row\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: C layout must be Row\n"); - return false; - } - - // Other padding modes have not been tested and do not get checked individually. - if constexpr(GemmSpec != GemmSpecialization::Default && - GemmSpec != GemmSpecialization::MNKOPadding) - { - print("Padding mode must be default or MNKO\n"); - return false; - } - - // Per wmma dimensions not equal to 16 are very untested. - if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16) - { - print("M, L, N per Wmma must be 16\n"); - return false; - } - - if(!GridwiseOp::CheckValidity(arg.a_grid_desc, - arg.b0_grid_desc, - arg.d0s_grid_desc, - arg.b1_grid_desc, - arg.d1s_grid_desc, - arg.e1_grid_desc_m_n, - arg.block_2_etile_map)) - { - return false; - } - - // Check scalar per vector requirement - const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; - const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; - const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major - const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; - const auto cde1_extent_lowest = arg.O; - - if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && - b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && - cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 && - b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && - cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) - { - print("DeviceOp: Data Transfer Vector scalar err\n"); - return false; - } - - // Check vector load/store requirement - const auto a_stride_lowest = - ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; - const auto b0_stride_lowest = - B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1]; - const auto b1_stride_lowest = - B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1]; - const auto e1_stride_lowest = arg.e1_g_m_o_strides[2]; - - // NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major - // and the lowest dimension stride is hardcoded to 1 - - if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || - e1_stride_lowest == 1)) - { - print("DeviceOp: Data Vectorize transfer err\n"); - return false; - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding)) - { - return false; - } - - return true; - } - // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { - return IsSupportedArgument(*dynamic_cast(p_arg)); + return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast(p_arg)); } - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::RawArg; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock); - const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock); - - const index_t grid_size = arg.batch_count * M0 * N0; - - auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) { - constexpr bool has_loop = decltype(has_main_k_block_loop)::value; - constexpr TailNumber tn = tail_number; - - const auto kernel = - kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3; - - return launch_and_time_kernel( - stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg); - }; - - bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K); - TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K); - - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else - { - printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n"); - return 0.0f; - } - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Even) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else - { - printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n"); - return 0.0f; - } - } - else - { - printf("Invalid pipeline version!\n"); - return 0.0f; - } - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - static auto MakeArgument(const ADataType* p_a0, const B0DataType* p_b0, std::array p_d0s, @@ -855,20 +321,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 B1ElementwiseOperation b1_element_op, CDE1ElementwiseOperation cde1_element_op) { - return RawArg{p_a0, p_b0, - p_d0s, p_b1, - p_d1s, p_e1, - MRaw, NRaw, - KRaw, Gemm1NRaw, - Batch, StrideA0, - StrideB0, StrideD0s, - StrideB1, StrideD1s, - StrideE1, BatchStrideA0, - BatchStrideB0, BatchStrideD0s, - BatchStrideB1, BatchStrideD1s, - BatchStrideE1, a0_element_op, - b0_element_op, cde0_element_op, - b1_element_op, cde1_element_op}; + return Argument{p_a0, p_b0, + p_d0s, p_b1, + p_d1s, p_e1, + MRaw, NRaw, + KRaw, Gemm1NRaw, + Batch, StrideA0, + StrideB0, StrideD0s, + StrideB1, StrideD1s, + StrideE1, BatchStrideA0, + BatchStrideB0, BatchStrideD0s, + BatchStrideB1, BatchStrideD1s, + BatchStrideE1, a0_element_op, + b0_element_op, cde0_element_op, + b1_element_op, cde1_element_op}; } // polymorphic @@ -902,34 +368,34 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 B1ElementwiseOperation b1_element_op, CDE1ElementwiseOperation c_element_op) override { - return std::make_unique(static_cast(p_a), - static_cast(p_b0), - p_d0s, - static_cast(p_b1), - p_d1s, - static_cast(p_c), - M, - N, - K, - O, - Batch, - StrideA, - StrideB0, - StrideD0s, - StrideB1, - StrideD1s, - StrideE1, - BatchStrideA, - BatchStrideB0, - BatchStrideD0s, - BatchStrideB1, - BatchStrideD1s, - BatchStrideE1, - a_element_op, - b0_element_op, - acc_element_op, - b1_element_op, - c_element_op); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + p_d0s, + static_cast(p_b1), + p_d1s, + static_cast(p_c), + M, + N, + K, + O, + Batch, + StrideA, + StrideB0, + StrideD0s, + StrideB1, + StrideD1s, + StrideE1, + BatchStrideA, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideE1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); } static auto MakeInvoker() { return Invoker{}; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp index 126d107725..ae247f4e31 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -606,6 +606,26 @@ struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3 return false; } + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + return GridwiseGemm::CheckValidity(arg); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp index 227a8aedd9..593a908498 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp @@ -588,6 +588,28 @@ struct DeviceBatchedGemmReduce_Wmma_CShuffleV3 return false; } + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + typename GridwiseGemm::Argument gemm_arg{std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, std::array{}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp index 59a820861c..fb1ca3127e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp @@ -455,6 +455,26 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_Common return false; } + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + return GridwiseGemm::CheckValidity(arg); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp index e8e3b69cb5..85ca16b293 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp @@ -471,6 +471,28 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3 return false; } + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + typename GridwiseGemm::Argument gemm_arg{ std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp index f0216c3f71..81f505b594 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp @@ -701,6 +701,28 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3 return false; } + if(ck::is_gfx12_supported() && + !GridwiseGemmWelford::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && + !GridwiseGemmWelford::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + typename GridwiseGemmWelford::Argument gemm_arg{ std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp index 317c4073df..28c9f2bddc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp @@ -456,6 +456,28 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera return false; } + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + typename GridwiseGemm::Argument gemm_arg{std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, std::array{}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index e96ec58cba..c09befa717 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -421,6 +421,26 @@ struct DeviceGemm_Wmma_CShuffleV3_Common } } + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + return GridwiseGemm::CheckValidity(arg); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp index e09c69d052..377f792979 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp @@ -393,6 +393,26 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1(&arg)); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index bbf62d5fbe..dfdfd53725 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -450,8 +450,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 BlkGemmPipelineVer, AComputeType, BComputeType, - false, - false>; + false, // PermuteA + false, // PermuteB + false, // IsBPreShuffled + true>; // ForceThreadTileTransfer #define GridwiseGemmCTransposeTemplateParameters \ ALayout, BLayout, DsLayout, ELayout, Tuple, Tuple, AccDataType, \ @@ -467,7 +469,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 ABlockLdsExtraM, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, \ CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CShuffleBlockTransferScalarPerVector, BlkGemmPipeSched, BlkGemmPipelineVer, BComputeType, \ - AComputeType, false, false + AComputeType, false, false, false, true using GridwiseGemmCTranspose = std::conditional_t()) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index bc072a7019..f662ff834f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -22,6 +22,7 @@ #include #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" @@ -524,6 +525,44 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{}, 1, 1)); + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return; + } + constexpr int dynamic_smem_size = 0; + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + int max_occupancy = 0; + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // TODO: implement + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + struct Argument : public BaseArgument, public ArgumentSplitK { Argument( @@ -574,6 +613,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads} { + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + constexpr index_t spatial_offset = 3; std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, end(b_g_n_c_wis_lengths), @@ -585,7 +626,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -602,6 +642,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); k_batch_ = std::min(k_batch_, k_batch_max); + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max @@ -611,7 +654,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 } } else -#endif { k_batch_ = split_k; } @@ -988,13 +1030,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index d3bf2a364a..033e0b8745 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -677,7 +677,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN; @@ -688,9 +687,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size); + + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); } else -#endif { k_batch_ = split_k; } @@ -947,12 +948,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif if(!ck::is_xdl_wmma_supported()) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 3f8093afe1..b2ae092c27 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -511,7 +511,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -528,6 +528,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); k_batch_ = std::min(k_batch_, k_batch_max); + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max @@ -537,7 +540,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 } } else -#endif { k_batch_ = split_k; } @@ -1040,12 +1042,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 976b6f1ef8..3e19b082ee 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -651,7 +651,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, e_g_k_c_xs_strides); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN; @@ -662,9 +661,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size); + + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); } else -#endif { k_batch_ = split_k; } @@ -1083,12 +1084,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif if(!ck::is_xdl_wmma_supported()) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index cea62ef281..f898db218f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -568,7 +568,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -585,6 +584,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const auto k_batch_max = static_cast((gemmK - 1) / K0PerBlock); k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1); + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max @@ -594,7 +596,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } } else -#endif { k_batch_ = split_k; } @@ -1373,12 +1374,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif // check device if constexpr(DirectLoad) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp index 5ae9eaf8ac..6b5776c4eb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp @@ -503,6 +503,29 @@ struct DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3 bool supported = true; for(index_t i = 0; i < arg.group_count_; ++i) { + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer( + arg.gemm_descs_[i].M_, arg.gemm_descs_[i].K_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer( + arg.gemm_descs_[i].N_, arg.gemm_descs_[i].K_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + std::array placeholder_p_ds_grid{}; std::array stride_Ds; std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin()); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp index 39024d39e4..99a18e07fc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp @@ -704,7 +704,28 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK(c / scale_in_ / scale_wei_ / scale_out_); }; + template <> + __host__ __device__ void operator()(f8_t& e, const f8_t& c) const + { + const float c_float = type_convert(c); + e = type_convert(c_float / scale_in_ / scale_wei_ / scale_out_); + }; + float scale_in_; float scale_wei_; float scale_out_; @@ -1656,6 +1663,13 @@ struct ConvScale e = type_convert(c * scale_in_ * scale_wei_ * scale_out_); }; + template <> + __host__ __device__ void operator()(f8_t& e, const f8_t& c) const + { + const float c_float = type_convert(c); + e = type_convert(c_float * scale_in_ * scale_wei_ * scale_out_); + }; + float scale_in_; float scale_wei_; float scale_out_; @@ -1683,6 +1697,15 @@ struct ConvScaleRelu e = type_convert(x * scale_out_); }; + template <> + __host__ __device__ void operator()(f8_t& e, const f8_t& c) const + { + const float c_float = type_convert(c); + float x; + Relu{}.template operator()(x, c_float * scale_in_ * scale_wei_); + e = type_convert(x * scale_out_); + }; + float scale_in_; float scale_wei_; float scale_out_; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index e47bb37a89..caf468d6cb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -132,10 +132,6 @@ struct ABTransferWaveTiles index_t, index_t) { - // Notes: padding is currently not supported with transpose - static_assert(!((PadMN || PadK) && ABDoTranspose), - "padding is currently not supported with transpose"); - const index_t MN_grid = !PadMN ? sizeMN : MNPad; const index_t K_grid = !PadK ? sizeK : KPad; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 5431c054fa..bcf131003c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -362,23 +362,27 @@ struct GridwiseGemm_wmma_cshuffle_v3_base WmmaSelector::selected_wmma .wave_size; + __host__ __device__ static constexpr bool AWaveTransferApplicable() + { + return !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && + ABlockTransferSrcScalarPerVector == 8 && ABlockTransferDstScalarPerVector_AK1 == 8 && + BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && + !IsBPreShuffled; + } + + __host__ __device__ static constexpr bool BWaveTransferApplicable() + { + return !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && + BBlockTransferSrcScalarPerVector == 8 && BBlockTransferDstScalarPerVector_BK1 == 8 && + BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; + } + // Limitations of the current implementation: // - no multiAB - // - GemmSpecialization Default with transpose #ifdef __gfx12__ - static constexpr bool IsAWaveTransferApplicable = - !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && - ((GemmSpec == tensor_operation::device::GemmSpecialization::Default && - !is_same_v) || - is_same_v) && - BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled; + static constexpr bool IsAWaveTransferApplicable = AWaveTransferApplicable(); - static constexpr bool IsBWaveTransferApplicable = - !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && - ((GemmSpec == tensor_operation::device::GemmSpecialization::Default && - !is_same_v) || - is_same_v) && - BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; + static constexpr bool IsBWaveTransferApplicable = BWaveTransferApplicable(); static constexpr bool IsWaveTileInterleavedFitting = (NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize); @@ -982,6 +986,55 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return de_grid_desc_mblock_mperblock_nblock_nperblock; } + // Conditions for Wave Transfer with transpose: + // - 16 bit type: K % 8 == 0 (4 subtiles of 8x8) + // - 8 bit type: K % 8 == 0 and M % 16 == 0 (2 subtiles of 8x16) + __host__ static constexpr bool CheckValidityAWaveTransfer(const index_t& M, const index_t& K) + { + if constexpr(AWaveTransferApplicable() && + !(is_same::value)) + { + if(!(K % ABlockTransferDstScalarPerVector_AK1 == 0)) + { + return false; + } + bool pass = true; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + pass &= !(sizeof(ADataType_) == 1 && + !(M % (2 * ABlockTransferSrcScalarPerVector) == 0)); + }); + return pass; + } + else + { + return true; + } + } + + __host__ static constexpr bool CheckValidityBWaveTransfer(const index_t& N, const index_t& K) + { + if constexpr(BWaveTransferApplicable() && + !(is_same::value)) + { + if(!(K % BBlockTransferDstScalarPerVector_BK1 == 0)) + { + return false; + } + bool pass = true; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + pass &= !(sizeof(BDataType_) == 1 && + !(N % (2 * BBlockTransferSrcScalarPerVector) == 0)); + }); + return pass; + } + else + { + return true; + } + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ static constexpr bool CheckValidity(const Argument& karg, diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 6e68690048..3a45d52bd3 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -199,55 +199,113 @@ template using make_index_sequence = typename __make_integer_seq::seq_type; -// merge sequence -template -struct sequence_merge +// merge sequence - optimized to avoid recursive instantiation +// +// Note: Unlike sequence_gen and uniform_sequence_gen which use __make_integer_seq for O(1) +// instantiation depth, sequence_merge cannot achieve O(1) depth. Here's why: +// +// - sequence_gen and uniform_sequence_gen generate a SINGLE output sequence where each +// element can be computed independently: output[i] = f(i) +// +// - sequence_merge takes MULTIPLE input sequences with different, unknown lengths. +// To compute output[i], we need to know: +// 1. Which input sequence contains this index +// 2. The offset within that sequence +// This requires computing cumulative sequence lengths, which requires recursion/iteration. +// +// Instead, we use a binary tree reduction approach that achieves O(log N) instantiation depth: +// - Base cases handle 1-4 sequences directly (O(1) for common cases) +// - Recursive case merges pairs then combines: merge(s1,s2) + merge(s3,s4,...) +// - This gives O(log N) depth, which is optimal for merging heterogeneous sequences +// +// Alternative considered: Fold expressions (... + sequences) would give O(N) depth due to +// linear dependency chain, so binary tree is superior. +// +namespace detail { + +// Helper to concatenate multiple sequences in one step using fold expression +template +struct sequence_merge_impl; + +// Base case: single sequence +template +struct sequence_merge_impl> { - using type = typename sequence_merge::type>::type; + using type = Sequence; }; +// Two sequences: direct concatenation template -struct sequence_merge, Sequence> +struct sequence_merge_impl, Sequence> { using type = Sequence; }; -template -struct sequence_merge +// Three sequences: direct concatenation (avoids one level of recursion) +template +struct sequence_merge_impl, Sequence, Sequence> { - using type = Seq; + using type = Sequence; }; -// generate sequence +// Four sequences: direct concatenation +template +struct sequence_merge_impl, Sequence, Sequence, Sequence> +{ + using type = Sequence; +}; + +// General case: binary tree reduction (O(log N) depth instead of O(N)) +template +struct sequence_merge_impl +{ + // Merge pairs first, then recurse + using left = typename sequence_merge_impl::type; + using right = typename sequence_merge_impl::type; + using type = typename sequence_merge_impl::type; +}; + +} // namespace detail + +template +struct sequence_merge +{ + using type = typename detail::sequence_merge_impl::type; +}; + +template <> +struct sequence_merge<> +{ + using type = Sequence<>; +}; + +// generate sequence - optimized using __make_integer_seq to avoid recursive instantiation +namespace detail { + +// Helper that applies functor F to indices and produces a Sequence +// __make_integer_seq produces sequence_gen_helper +template +struct sequence_gen_helper +{ + // Apply a functor F to all indices at once via pack expansion (O(1) depth) + template + using apply = Sequence{})...>; +}; + +} // namespace detail + template struct sequence_gen { - template - struct sequence_gen_impl - { - static constexpr index_t NRemainLeft = NRemain / 2; - static constexpr index_t NRemainRight = NRemain - NRemainLeft; - static constexpr index_t IMiddle = IBegin + NRemainLeft; + using type = + typename __make_integer_seq::template apply; +}; - using type = typename sequence_merge< - typename sequence_gen_impl::type, - typename sequence_gen_impl::type>::type; - }; - - template - struct sequence_gen_impl - { - static constexpr index_t Is = G{}(Number{}); - using type = Sequence; - }; - - template - struct sequence_gen_impl - { - using type = Sequence<>; - }; - - using type = typename sequence_gen_impl<0, NSize, F>::type; +template +struct sequence_gen<0, F> +{ + using type = Sequence<>; }; // arithmetic sequence @@ -283,16 +341,30 @@ struct arithmetic_sequence_gen<0, IEnd, 1> using type = typename __make_integer_seq::type; }; -// uniform sequence +// uniform sequence - optimized using __make_integer_seq +namespace detail { + +template +struct uniform_sequence_helper +{ + // Apply a constant value to all indices via pack expansion + template + using apply = Sequence<((void)Is, Value)...>; +}; + +} // namespace detail + template struct uniform_sequence_gen { - struct F - { - __host__ __device__ constexpr index_t operator()(index_t) const { return I; } - }; + using type = typename __make_integer_seq:: + template apply; +}; - using type = typename sequence_gen::type; +template +struct uniform_sequence_gen<0, I> +{ + using type = Sequence<>; }; // reverse inclusive scan (with init) sequence diff --git a/include/ck/utility/statically_indexed_array.hpp b/include/ck/utility/statically_indexed_array.hpp index d0735a32f6..f3d73e84a7 100644 --- a/include/ck/utility/statically_indexed_array.hpp +++ b/include/ck/utility/statically_indexed_array.hpp @@ -20,6 +20,7 @@ struct tuple_concat, Tuple> using type = Tuple; }; +// StaticallyIndexedArrayImpl uses binary split for O(log N) depth template struct StaticallyIndexedArrayImpl { diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index a1be8027b2..2ba3b1e7c3 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -19,7 +19,7 @@ namespace ck_tile { /** @brief Maximum number of error values to display when checking errors */ -constexpr int ERROR_DETAIL_LIMIT = 128; +constexpr int ERROR_DETAIL_LIMIT = 16; /** @brief 8-bit floating point type */ using F8 = ck_tile::fp8_t; diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index c4ab1d4a78..34d18cb8e1 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -227,7 +227,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<1>>{}); else return make_static_tile_distribution( - tile_distribution_encoding< // + tile_distribution_encoding< sequence, tuple, sequence>, diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 705a992b52..9d19e902e5 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -274,7 +274,9 @@ struct AQuantBlockUniversalGemmAsBsCr static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { CWarpTensor c_warp_tensor; + // for every column in AQ static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + // for every warp corresponding to a quantization scale static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; @@ -322,6 +324,214 @@ struct AQuantBlockUniversalGemmAsBsCr } }; + template + struct BlockGemmImpl + { + static constexpr index_t KPerThread = GemmTraits::KPerThread; + static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters; + + static constexpr index_t KPerInnerLoop = + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); + static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; + static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread; + + static constexpr auto ALdsTileDistr = + make_static_tile_distribution(MakeABlockDistributionEncode()); + static constexpr auto BLdsTileDistr = + make_static_tile_distribution(MakeBBlockDistributionEncode()); + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_warp_tile_; + BLdsTile b_warp_tile_; + + template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) + { + constexpr auto a_lds_load_distr = [&]() { + if constexpr(ALoadTranspose) + return make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(MakeABlockDistributionEncode()), + ADataType>::TransposedDstrEncode{}); + else + return make_static_tile_distribution(MakeABlockDistributionEncode()); + }(); + constexpr auto b_lds_load_distr = [&]() { + if constexpr(BLoadTranspose) + return make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(MakeBBlockDistributionEncode()), + BDataType>::TransposedDstrEncode{}); + else + return make_static_tile_distribution(MakeBBlockDistributionEncode()); + }(); + constexpr auto a_lds_shape = []() { + if constexpr(ALoadTranspose) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + constexpr auto b_lds_shape = []() { + if constexpr(BLoadTranspose) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + constexpr auto k_idx_offset = KIdx * KPerInnerLoop; + constexpr auto a_offset = + ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset}; + constexpr auto b_offset = + BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset}; + + auto a_lds_gemm_window = make_tile_window( + a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr); + auto b_lds_gemm_window = make_tile_window( + b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); + + load_int4_tile( + a_warp_tile_, a_lds_gemm_window); + load_int4_tile( + b_warp_tile_, b_lds_gemm_window); + } + + // C += A * B with quantization support + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + AQBlockTensor& aq_block_tensor, + const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) + { + static_assert(std::is_same_v, + "The CDataType as defined in traits should be the same as corresponding " + "C block tensor data type!"); + constexpr auto warp_size = get_warp_size(); + + // Track which KRepeat chunk is currently loaded + index_t current_k_repeat_loaded = -1; + + // Restructured loop: M → N → QScale → KIterPerQScale + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Iterate over quantization groups + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + CWarpTensor c_warp_tensor; + + // Accumulate K iterations for this quantization group + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + // Map quantization indices to global K iteration + constexpr auto kIterGlobal = + kQScale * Traits::KIterPerQScale + kIterInQScale; + + // Map to KRepeat chunk and KInnerLoopIter offset + constexpr auto kRepeatIdx = kIterGlobal / KInnerLoopIter; + constexpr auto kInnerIdx = kIterGlobal % KInnerLoopIter; + + // Prefetch new chunk if needed + if constexpr(kInnerIdx == 0) + { + if(current_k_repeat_loaded != kRepeatIdx) + { + LocalPrefetch( + a_block_window, b_block_window, a_load_tr, b_load_tr); + __builtin_amdgcn_sched_barrier(0); + + if constexpr(kRepeatIdx != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + + current_k_repeat_loaded = kRepeatIdx; + } + } + + // Load A warp tensor + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = + a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, + a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // Load B warp tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = + b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, + b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // Synchronization barrier at the end of last iteration + if constexpr(kQScale == Traits::QScalesPerBlockRow - 1 && + kIterInQScale == Traits::KIterPerQScale - 1 && + mIter.value == MIterPerWarp - 1 && + nIter.value == NIterPerWarp - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + + // Accumulate: first iteration initializes, rest accumulate + if constexpr(kIterInQScale == 0) + { + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + } + else + { + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } + + // Set priority for scheduling + if constexpr(kInnerIdx == 0 && mIter.value == 0 && nIter.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + + // Apply quantization scale after accumulating all K iterations for this + // group + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + AQPickerCommon aq_picker( + aq_block_tensor); + + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + float scale_reg_f = aq_picker.template pick(); + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); + }); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + }; + public: CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { @@ -329,7 +539,8 @@ struct AQuantBlockUniversalGemmAsBsCr MakeCBlockTile(); } - template @@ -338,7 +549,15 @@ struct AQuantBlockUniversalGemmAsBsCr bool_constant a_load_tr = {}, bool_constant b_load_tr = {}) { - block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); + if constexpr(Scheduler == GemmPipelineScheduler::Interwave) + { + block_gemm_impl_.template LocalPrefetch( + a_block_window, b_block_window, a_load_tr, b_load_tr); + } + else + { + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); + } } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 650cd947f7..b87c12c14a 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -499,7 +499,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return PipelineImpl{} .template operator()( a_dram_block_window_tmp, - [](const OverrideADataType& a) { return a; }, + [](const BDataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp, diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index 4284e7622f..3f59e2d036 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -392,8 +392,4 @@ struct BlockReduce2D InDataType reduce_init; }; -// deduction guide -template -CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&) -> BlockReduce2D; - } // namespace ck_tile diff --git a/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp b/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp index abb95934ff..58e768b319 100644 --- a/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp +++ b/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp @@ -40,7 +40,7 @@ struct BlockSoftmax2D #endif // compute row max - auto reduce_row_max = BlockReduce2D{x, -numeric::infinity()}; + auto reduce_row_max = BlockReduce2D{x, -numeric::infinity()}; #if _BLOCK_SOFTMAX_USE_UNPACK2 auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{}); #else diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp index aecf519c10..5210265cef 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp @@ -10,49 +10,55 @@ #include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include namespace ck { namespace ref { -// Optimized backward data convolution kernel working with packed (contiguous) tensors -// Computes gradients w.r.t. input from output gradients and weights -// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], -// output[G][N][K][spatial] +// Optimized backward data convolution kernel working with packed (contiguous) tensors with +// multi-ABD support Computes gradients w.r.t. input from output gradients and weights Assumes +// row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], output[G][N][K][spatial] template -__global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, - const WeiDataType* __restrict__ p_wei, - const OutDataType* __restrict__ p_out, - index_t G, - index_t N, - index_t K, - index_t C, - index_t Di, - index_t Hi, - index_t Wi, - index_t Z, - index_t Y, - index_t X, - index_t Do, - index_t Ho, - index_t Wo, - index_t stride_z, - index_t stride_y, - index_t stride_x, - index_t dilation_z, - index_t dilation_y, - index_t dilation_x, - index_t pad_z, - index_t pad_y, - index_t pad_x, - InElementOp in_op, - WeiElementOp wei_op, - OutElementOp out_op) +__global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_in, + const WeiDataType* const* __restrict__ p_weis, + const OutDataType* const* __restrict__ p_outs, + const DDataType* const* __restrict__ p_ds, + const index_t* const* __restrict__ p_d_strides, + index_t G, + index_t N, + index_t K, + index_t C, + index_t Di, + index_t Hi, + index_t Wi, + index_t Z, + index_t Y, + index_t X, + index_t Do, + index_t Ho, + index_t Wo, + index_t stride_z, + index_t stride_y, + index_t stride_x, + index_t dilation_z, + index_t dilation_y, + index_t dilation_x, + index_t pad_z, + index_t pad_y, + index_t pad_x, + InElementOp in_op, + WeiElementOp wei_op, + OutElementOp out_op) { const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x; const long_index_t num_threads = blockDim.x * gridDim.x; @@ -84,9 +90,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n; - const WeiDataType* wei_g = p_wei + g * wei_stride_g; + float acc = 0.0f; + // Base pointers for current group and batch + const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n; + const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g; for(index_t x = 0; x < X; ++x) { @@ -96,21 +103,39 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t wo = w_tmp / stride_x; if(wo >= 0 && wo < Wo) { - const OutDataType* out_gnk = out_gn; - const WeiDataType* wei_gkc = wei_g + c * wei_stride_c; + // Pointers at current filter position + const OutDataType* output_grad_g_n_k = output_grad_g_n; + const WeiDataType* weight_g_k_c = weight_g + c * wei_stride_c; for(index_t k = 0; k < K; ++k) { - out_op(out_val, out_gnk[k * out_stride_k + wo]); - wei_op(wei_val, wei_gkc[k * wei_stride_k + x]); + // Handle output gradient element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_g_n_k, + p_outs + 1, + g * out_stride_g + n * out_stride_n, + k * out_stride_k + wo); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_g_k_c, + p_weis + 1, + g * wei_stride_g + c * wei_stride_c, + k * wei_stride_k + x); + acc += type_convert(out_val) * type_convert(wei_val); } } } } - InDataType result = type_convert(acc); - in_op(in_val, result); + detail::apply_d_tensor_elementwise_op( + in_val, in_op, acc, p_ds, p_d_strides, g, n, c, wi); + p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + wi] = in_val; } } @@ -142,9 +167,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n; - const WeiDataType* wei_g = p_wei + g * wei_stride_g; + float acc = 0.0f; + // Base pointers for current group and batch + const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n; + const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g; for(index_t y = 0; y < Y; ++y) { @@ -154,8 +180,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t ho = h_tmp / stride_y; if(ho >= 0 && ho < Ho) { - const OutDataType* out_gnkh = out_gn + ho * out_stride_h; - const WeiDataType* wei_gkcy = wei_g + c * wei_stride_c + y * wei_stride_y; + // Pointers at current spatial height and filter Y position + const OutDataType* output_grad_at_h = output_grad_g_n + ho * out_stride_h; + const WeiDataType* weight_at_c_y = + weight_g + c * wei_stride_c + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { @@ -167,8 +195,25 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, { for(index_t k = 0; k < K; ++k) { - out_op(out_val, out_gnkh[k * out_stride_k + wo]); - wei_op(wei_val, wei_gkcy[k * wei_stride_k + x]); + // Handle output gradient element-wise operation with extra + // A tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_h, + p_outs + 1, + g * out_stride_g + n * out_stride_n + ho * out_stride_h, + k * out_stride_k + wo); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_c_y, + p_weis + 1, + g * wei_stride_g + c * wei_stride_c + y * wei_stride_y, + k * wei_stride_k + x); + acc += type_convert(out_val) * type_convert(wei_val); } @@ -179,8 +224,17 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, } } - InDataType result = type_convert(acc); - in_op(in_val, result); + detail::apply_d_tensor_elementwise_op(in_val, + in_op, + acc, + p_ds, + p_d_strides, + g, + n, + c, + hi * p_d_strides[0][3] + + wi * p_d_strides[0][4]); + p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + hi * in_stride_h + wi] = in_val; } @@ -218,9 +272,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n; - const WeiDataType* wei_g = p_wei + g * wei_stride_g; + float acc = 0.0f; + // Base pointers for current group and batch + const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n; + const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g; for(index_t z = 0; z < Z; ++z) { @@ -230,8 +285,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t do_idx = d_tmp / stride_z; if(do_idx >= 0 && do_idx < Do) { - const OutDataType* out_gnkd = out_gn + do_idx * out_stride_d; - const WeiDataType* wei_gkcz = wei_g + c * wei_stride_c + z * wei_stride_z; + // Pointers at current spatial depth + const OutDataType* output_grad_at_d = + output_grad_g_n + do_idx * out_stride_d; + const WeiDataType* weight_at_c_z = + weight_g + c * wei_stride_c + z * wei_stride_z; for(index_t y = 0; y < Y; ++y) { @@ -241,8 +299,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t ho = h_tmp / stride_y; if(ho >= 0 && ho < Ho) { - const OutDataType* out_gnkdh = out_gnkd + ho * out_stride_h; - const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y; + // Pointers at current spatial depth and height + const OutDataType* output_grad_at_d_h = + output_grad_at_d + ho * out_stride_h; + const WeiDataType* weight_at_c_z_y = + weight_at_c_z + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { @@ -254,10 +315,31 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, { for(index_t k = 0; k < K; ++k) { - out_op(out_val, - out_gnkdh[k * out_stride_k + wo]); - wei_op(wei_val, - wei_gkczy[k * wei_stride_k + x]); + // Handle output gradient element-wise operation + // with extra A tensors + detail::apply_multi_tensor_elementwise_op< + NumAExtra>(out_val, + out_op, + output_grad_at_d_h, + p_outs + 1, + g * out_stride_g + + n * out_stride_n + + do_idx * out_stride_d + + ho * out_stride_h, + k * out_stride_k + wo); + + // Handle weight element-wise operation with + // extra B tensors + detail::apply_multi_tensor_elementwise_op< + NumBExtra>( + wei_val, + wei_op, + weight_at_c_z_y, + p_weis + 1, + g * wei_stride_g + c * wei_stride_c + + z * wei_stride_z + y * wei_stride_y, + k * wei_stride_k + x); + acc += type_convert(out_val) * type_convert(wei_val); } @@ -271,16 +353,28 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, } } - InDataType result = type_convert(acc); - in_op(in_val, result); + detail::apply_d_tensor_elementwise_op( + in_val, + in_op, + acc, + p_ds, + p_d_strides, + g, + n, + c, + di * p_d_strides[0][3] + hi * p_d_strides[0][4] + wi * p_d_strides[0][5]); + p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + di * in_stride_d + hi * in_stride_h + wi] = in_val; } } } -// GPU reference backward data convolution - takes ConvParam directly -template -void naive_conv_bwd_data(TIn* p_in, - const TWei* p_wei, - const TOut* p_out, - const ck::utils::conv::ConvParam& conv_param, - InElementwiseOperation in_element_op = InElementwiseOperation{}, - WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, - OutElementwiseOperation out_element_op = OutElementwiseOperation{}, - hipStream_t stream = nullptr) + typename OutElementwiseOperation, + typename TD = TIn> // D tensor type, defaults to TIn for backward compatibility +void naive_conv_bwd_data_multi_abd( + TIn* p_in, + const std::array& p_weis, + const std::array& p_outs, + const std::array& p_ds, + const ck::utils::conv::ConvParam& conv_param, + [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, + const std::array, NumDElementwise>& d_strides, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) { const auto ndim = conv_param.num_dim_spatial_; @@ -327,12 +426,34 @@ void naive_conv_bwd_data(TIn* p_in, // Allocate packed buffers SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn)); - SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei)); - SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut)); - TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); - TWei* p_wei_packed = static_cast(wei_packed_buf.GetDeviceBuffer()); - TOut* p_out_packed = static_cast(out_packed_buf.GetDeviceBuffer()); + std::vector wei_packed_bufs; + wei_packed_bufs.reserve(NumBElementwise + 1); + for(index_t i = 0; i <= NumBElementwise; ++i) + { + wei_packed_bufs.emplace_back(wei_total * sizeof(TWei)); + } + + std::vector out_packed_bufs; + out_packed_bufs.reserve(NumAElementwise + 1); + for(index_t i = 0; i <= NumAElementwise; ++i) + { + out_packed_bufs.emplace_back(out_total * sizeof(TOut)); + } + + TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); + + std::array p_weis_packed; + for(index_t i = 0; i <= NumBElementwise; ++i) + { + p_weis_packed[i] = static_cast(wei_packed_bufs[i].GetDeviceBuffer()); + } + + std::array p_outs_packed; + for(index_t i = 0; i <= NumAElementwise; ++i) + { + p_outs_packed[i] = static_cast(out_packed_bufs[i].GetDeviceBuffer()); + } // Compute strides and allocate device arrays for pack/unpack std::vector in_strides = compute_conv_tensor_strides(in_lengths, ndim); @@ -369,12 +490,76 @@ void naive_conv_bwd_data(TIn* p_in, // Pack output and weight tensors to contiguous layout (inputs to bwd data) constexpr int block_size = 256; - strided_copy_kernel - <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_out, p_out_packed, d_out_lengths, d_out_strides, dim_count, out_total); - strided_copy_kernel - <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total); + + for(index_t i = 0; i <= NumAElementwise; ++i) + { + strided_copy_kernel + <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_outs[i], p_outs_packed[i], d_out_lengths, d_out_strides, dim_count, out_total); + } + + for(index_t i = 0; i <= NumBElementwise; ++i) + { + strided_copy_kernel + <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total); + } + + // Prepare D tensor stride arrays on device + std::vector d_stride_bufs; + std::array p_d_strides_dev = {}; + + if constexpr(NumDElementwise > 0) + { + d_stride_bufs.reserve(NumDElementwise); + + for(index_t i = 0; i < NumDElementwise; ++i) + { + d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t)); + p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i], + d_strides[i].data(), + d_strides[i].size() * sizeof(index_t), + hipMemcpyHostToDevice)); + } + } + + // Create device arrays of pointers + SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*)); + SimpleDeviceMem outs_ptrs_buf((NumAElementwise + 1) * sizeof(TOut*)); + SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*)); + SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*)); + + TWei** d_weis_ptrs = static_cast(weis_ptrs_buf.GetDeviceBuffer()); + TOut** d_outs_ptrs = static_cast(outs_ptrs_buf.GetDeviceBuffer()); + TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); + index_t** d_d_strides_ptrs = static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs, + p_weis_packed.data(), + (NumBElementwise + 1) * sizeof(TWei*), + hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_outs_ptrs, + p_outs_packed.data(), + (NumAElementwise + 1) * sizeof(TOut*), + hipMemcpyHostToDevice)); + + if constexpr(NumDElementwise > 0) + { + std::array p_ds_dev; + for(index_t i = 0; i < NumDElementwise; ++i) + { + p_ds_dev[i] = p_ds[i]; + } + + HIP_CHECK_ERROR(hipMemcpy( + d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs, + p_d_strides_dev.data(), + NumDElementwise * sizeof(index_t*), + hipMemcpyHostToDevice)); + } // Build conv parameter vectors for kernel invocation std::vector conv_strides(ndim); @@ -392,16 +577,22 @@ void naive_conv_bwd_data(TIn* p_in, if(ndim == 1) { - naive_conv_bwd_data_packed<1, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> + naive_conv_bwd_data_packed_multi_abd<1, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> <<>>(p_in_packed, - p_wei_packed, - p_out_packed, + d_weis_ptrs, + d_outs_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -430,16 +621,22 @@ void naive_conv_bwd_data(TIn* p_in, } else if(ndim == 2) { - naive_conv_bwd_data_packed<2, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> + naive_conv_bwd_data_packed_multi_abd<2, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> <<>>(p_in_packed, - p_wei_packed, - p_out_packed, + d_weis_ptrs, + d_outs_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -468,16 +665,22 @@ void naive_conv_bwd_data(TIn* p_in, } else // 3D { - naive_conv_bwd_data_packed<3, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> + naive_conv_bwd_data_packed_multi_abd<3, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> <<>>(p_in_packed, - p_wei_packed, - p_out_packed, + d_weis_ptrs, + d_outs_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -514,5 +717,43 @@ void naive_conv_bwd_data(TIn* p_in, // Memory automatically freed by SimpleDeviceMem destructors } +// Original naive_conv_bwd_data - now a zero-overhead wrapper +template +inline void naive_conv_bwd_data(TIn* p_in, + const TWei* p_wei, + const TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) +{ + std::array p_weis = {p_wei}; + std::array p_outs = {p_out}; + std::array p_ds = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + naive_conv_bwd_data_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_in, + p_weis, + p_outs, + p_ds, + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op, + stream); +} + } // namespace ref } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp index f46b072baa..8cee2e2b77 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp @@ -10,49 +10,58 @@ #include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include namespace ck { namespace ref { -// Optimized backward weight convolution kernel working with packed (contiguous) tensors +// Optimized backward weight convolution kernel working with packed (contiguous) tensors with +// multi-ABD support // Assumes row-major packing: input[G][N][C][spatial], output_grad[G][N][K][spatial], // weight_grad[G][K][C][filter] // Computes gradient with respect to weights template -__global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in, - WeiDataType* __restrict__ p_wei_grad, - const OutDataType* __restrict__ p_out_grad, - index_t G, - index_t N, - index_t K, - index_t C, - index_t Di, - index_t Hi, - index_t Wi, - index_t Z, - index_t Y, - index_t X, - index_t Do, - index_t Ho, - index_t Wo, - index_t stride_z, - index_t stride_y, - index_t stride_x, - index_t dilation_z, - index_t dilation_y, - index_t dilation_x, - index_t pad_z, - index_t pad_y, - index_t pad_x, - InElementOp in_op, - WeiElementOp wei_op, - OutElementOp out_op) +__global__ void +naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_ins, + WeiDataType* __restrict__ p_wei_grad, + const OutDataType* const* __restrict__ p_out_grads, + const DDataType* const* __restrict__ p_ds, + const index_t* const* __restrict__ p_d_strides, + index_t G, + index_t N, + index_t K, + index_t C, + index_t Di, + index_t Hi, + index_t Wi, + index_t Z, + index_t Y, + index_t X, + index_t Do, + index_t Ho, + index_t Wo, + index_t stride_z, + index_t stride_y, + index_t stride_x, + index_t dilation_z, + index_t dilation_y, + index_t dilation_x, + index_t pad_z, + index_t pad_y, + index_t pad_x, + InElementOp in_op, + WeiElementOp wei_op, + OutElementOp out_op) { const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x; const long_index_t num_threads = blockDim.x * gridDim.x; @@ -84,30 +93,50 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in const index_t k = remaining % K; const index_t g = remaining / K; - float acc = 0.0f; - const InDataType* in_g = p_in + g * in_stride_g; - const OutDataType* out_grad = p_out_grad + g * out_stride_g; + float acc = 0.0f; + // Base pointers for current group + const InDataType* input_g = p_ins[0] + g * in_stride_g; + const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g; // Loop over batch and output positions for(index_t n = 0; n < N; ++n) { - const InDataType* in_gn = in_g + n * in_stride_n + c * in_stride_c; - const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k; + // Pointers at current batch and input channel + const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c; + const OutDataType* output_grad_at_n_k = + output_grad_g + n * out_stride_n + k * out_stride_k; for(index_t wo = 0; wo < Wo; ++wo) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gn[wi]); - out_op(out_val, out_gn_k[wo]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_n_c, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c, + wi); + + // Handle output gradient element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_n_k, + p_out_grads + 1, + g * out_stride_g + n * out_stride_n + k * out_stride_k, + wo); + acc += type_convert(out_val) * type_convert(in_val); } } } - WeiDataType result = type_convert(acc); - wei_op(wei_val, result); + detail::apply_d_tensor_elementwise_op( + wei_val, wei_op, acc, p_ds, p_d_strides, g, k, c, x); + p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + x] = wei_val; } } @@ -139,31 +168,55 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in const index_t k = remaining % K; const index_t g = remaining / K; - float acc = 0.0f; - const InDataType* in_g = p_in + g * in_stride_g; - const OutDataType* out_grad = p_out_grad + g * out_stride_g; + float acc = 0.0f; + // Base pointers for current group + const InDataType* input_g = p_ins[0] + g * in_stride_g; + const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g; // Loop over batch and output positions for(index_t n = 0; n < N; ++n) { - const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c; - const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k; + // Pointers at current batch and input channel + const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c; + const OutDataType* output_grad_at_n_k = + output_grad_g + n * out_stride_n + k * out_stride_k; for(index_t ho = 0; ho < Ho; ++ho) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gnch = in_gnc + hi * in_stride_h; - const OutDataType* out_gn_kh = out_gn_k + ho * out_stride_h; + // Pointers at current spatial height + const InDataType* input_at_h = input_at_n_c + hi * in_stride_h; + const OutDataType* output_grad_at_h = + output_grad_at_n_k + ho * out_stride_h; for(index_t wo = 0; wo < Wo; ++wo) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gnch[wi]); - out_op(out_val, out_gn_kh[wo]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + hi * in_stride_h, + wi); + + // Handle output gradient element-wise operation with extra B + // tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_h, + p_out_grads + 1, + g * out_stride_g + n * out_stride_n + k * out_stride_k + + ho * out_stride_h, + wo); + acc += type_convert(out_val) * type_convert(in_val); } } @@ -171,8 +224,17 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in } } - WeiDataType result = type_convert(acc); - wei_op(wei_val, result); + detail::apply_d_tensor_elementwise_op(wei_val, + wei_op, + acc, + p_ds, + p_d_strides, + g, + k, + c, + y * p_d_strides[0][3] + + x * p_d_strides[0][4]); + p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + y * wei_stride_y + x] = wei_val; } @@ -210,39 +272,65 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in const index_t k = remaining % K; const index_t g = remaining / K; - float acc = 0.0f; - const InDataType* in_g = p_in + g * in_stride_g; - const OutDataType* out_grad = p_out_grad + g * out_stride_g; + float acc = 0.0f; + // Base pointers for current group + const InDataType* input_g = p_ins[0] + g * in_stride_g; + const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g; // Loop over batch and output positions for(index_t n = 0; n < N; ++n) { - const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c; - const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k; + // Pointers at current batch and input channel + const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c; + const OutDataType* output_grad_at_n_k = + output_grad_g + n * out_stride_n + k * out_stride_k; for(index_t do_idx = 0; do_idx < Do; ++do_idx) { long_index_t di = do_idx * stride_z + z * dilation_z - pad_z; if(di >= 0 && di < Di) { - const InDataType* in_gncd = in_gnc + di * in_stride_d; - const OutDataType* out_gn_kd = out_gn_k + do_idx * out_stride_d; + // Pointers at current spatial depth + const InDataType* input_at_d = input_at_n_c + di * in_stride_d; + const OutDataType* output_grad_at_d = + output_grad_at_n_k + do_idx * out_stride_d; for(index_t ho = 0; ho < Ho; ++ho) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gncdh = in_gncd + hi * in_stride_h; - const OutDataType* out_gn_kdh = out_gn_kd + ho * out_stride_h; + // Pointers at current spatial depth and height + const InDataType* input_at_d_h = input_at_d + hi * in_stride_h; + const OutDataType* output_grad_at_d_h = + output_grad_at_d + ho * out_stride_h; for(index_t wo = 0; wo < Wo; ++wo) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gncdh[wi]); - out_op(out_val, out_gn_kdh[wo]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_d_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + di * in_stride_d + hi * in_stride_h, + wi); + + // Handle output gradient element-wise operation with extra + // B tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_d_h, + p_out_grads + 1, + g * out_stride_g + n * out_stride_n + k * out_stride_k + + do_idx * out_stride_d + ho * out_stride_h, + wo); + acc += type_convert(out_val) * type_convert(in_val); } @@ -253,16 +341,28 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in } } - WeiDataType result = type_convert(acc); - wei_op(wei_val, result); + detail::apply_d_tensor_elementwise_op( + wei_val, + wei_op, + acc, + p_ds, + p_d_strides, + g, + k, + c, + z * p_d_strides[0][3] + y * p_d_strides[0][4] + x * p_d_strides[0][5]); + p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + z * wei_stride_z + y * wei_stride_y + x] = wei_val; } } } -// GPU reference backward weight convolution - takes ConvParam directly -template -void naive_conv_bwd_weight(const TIn* p_in, - TWei* p_wei_grad, - const TOut* p_out, - const ck::utils::conv::ConvParam& conv_param, - InElementwiseOperation in_element_op = InElementwiseOperation{}, - WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, - OutElementwiseOperation out_element_op = OutElementwiseOperation{}, - hipStream_t stream = nullptr) + typename OutElementwiseOperation, + typename TD = TWei> // D tensor type, defaults to TWei for backward compatibility +void naive_conv_bwd_weight_multi_abd( + const std::array& p_ins, + TWei* p_wei_grad, + const std::array& p_outs, + const std::array& p_ds, + const ck::utils::conv::ConvParam& conv_param, + [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, + const std::array, NumDElementwise>& d_strides, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) { const auto ndim = conv_param.num_dim_spatial_; @@ -308,13 +413,35 @@ void naive_conv_bwd_weight(const TIn* p_in, out_total *= l; // Allocate packed buffers - SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn)); - SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei)); - SimpleDeviceMem out_grad_packed_buf(out_total * sizeof(TOut)); + std::vector in_packed_bufs; + in_packed_bufs.reserve(NumAElementwise + 1); + for(index_t i = 0; i <= NumAElementwise; ++i) + { + in_packed_bufs.emplace_back(in_total * sizeof(TIn)); + } + + SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei)); + + std::vector out_grad_packed_bufs; + out_grad_packed_bufs.reserve(NumBElementwise + 1); + for(index_t i = 0; i <= NumBElementwise; ++i) + { + out_grad_packed_bufs.emplace_back(out_total * sizeof(TOut)); + } + + std::array p_ins_packed; + for(index_t i = 0; i <= NumAElementwise; ++i) + { + p_ins_packed[i] = static_cast(in_packed_bufs[i].GetDeviceBuffer()); + } - TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); TWei* p_wei_grad_packed = static_cast(wei_grad_packed_buf.GetDeviceBuffer()); - TOut* p_out_grad_packed = static_cast(out_grad_packed_buf.GetDeviceBuffer()); + + std::array p_out_grads_packed; + for(index_t i = 0; i <= NumBElementwise; ++i) + { + p_out_grads_packed[i] = static_cast(out_grad_packed_bufs[i].GetDeviceBuffer()); + } // Compute strides and allocate device arrays for pack/unpack std::vector in_strides = compute_conv_tensor_strides(in_lengths, ndim); @@ -351,12 +478,81 @@ void naive_conv_bwd_weight(const TIn* p_in, // Pack input and output_grad tensors to contiguous layout (inputs to bwd weight) constexpr int block_size = 256; - strided_copy_kernel - <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total); - strided_copy_kernel - <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_out, p_out_grad_packed, d_out_lengths, d_out_strides, dim_count, out_total); + + for(index_t i = 0; i <= NumAElementwise; ++i) + { + strided_copy_kernel + <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total); + } + + for(index_t i = 0; i <= NumBElementwise; ++i) + { + strided_copy_kernel + <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_outs[i], + p_out_grads_packed[i], + d_out_lengths, + d_out_strides, + dim_count, + out_total); + } + + // Prepare D tensor stride arrays on device + std::vector d_stride_bufs; + std::array p_d_strides_dev = {}; + + if constexpr(NumDElementwise > 0) + { + d_stride_bufs.reserve(NumDElementwise); + + for(index_t i = 0; i < NumDElementwise; ++i) + { + d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t)); + p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i], + d_strides[i].data(), + d_strides[i].size() * sizeof(index_t), + hipMemcpyHostToDevice)); + } + } + + // Create device arrays of pointers + SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*)); + SimpleDeviceMem out_grads_ptrs_buf((NumBElementwise + 1) * sizeof(TOut*)); + SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*)); + SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*)); + + TIn** d_ins_ptrs = static_cast(ins_ptrs_buf.GetDeviceBuffer()); + TOut** d_out_grads_ptrs = static_cast(out_grads_ptrs_buf.GetDeviceBuffer()); + TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); + index_t** d_d_strides_ptrs = static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs, + p_ins_packed.data(), + (NumAElementwise + 1) * sizeof(TIn*), + hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_out_grads_ptrs, + p_out_grads_packed.data(), + (NumBElementwise + 1) * sizeof(TOut*), + hipMemcpyHostToDevice)); + + if constexpr(NumDElementwise > 0) + { + std::array p_ds_dev; + for(index_t i = 0; i < NumDElementwise; ++i) + { + p_ds_dev[i] = p_ds[i]; + } + + HIP_CHECK_ERROR(hipMemcpy( + d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs, + p_d_strides_dev.data(), + NumDElementwise * sizeof(index_t*), + hipMemcpyHostToDevice)); + } // Build conv parameter vectors for kernel invocation std::vector conv_strides(ndim); @@ -374,16 +570,22 @@ void naive_conv_bwd_weight(const TIn* p_in, if(ndim == 1) { - naive_conv_bwd_weight_packed<1, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, + naive_conv_bwd_weight_packed_multi_abd<1, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, p_wei_grad_packed, - p_out_grad_packed, + d_out_grads_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -412,16 +614,22 @@ void naive_conv_bwd_weight(const TIn* p_in, } else if(ndim == 2) { - naive_conv_bwd_weight_packed<2, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, + naive_conv_bwd_weight_packed_multi_abd<2, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, p_wei_grad_packed, - p_out_grad_packed, + d_out_grads_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -450,16 +658,22 @@ void naive_conv_bwd_weight(const TIn* p_in, } else // 3D { - naive_conv_bwd_weight_packed<3, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, + naive_conv_bwd_weight_packed_multi_abd<3, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, p_wei_grad_packed, - p_out_grad_packed, + d_out_grads_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -496,5 +710,44 @@ void naive_conv_bwd_weight(const TIn* p_in, // Memory automatically freed by SimpleDeviceMem destructors } +// Original naive_conv_bwd_weight - now a zero-overhead wrapper +template +inline void +naive_conv_bwd_weight(const TIn* p_in, + TWei* p_wei_grad, + const TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) +{ + std::array p_ins = {p_in}; + std::array p_outs = {p_out}; + std::array p_ds = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + naive_conv_bwd_weight_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_ins, + p_wei_grad, + p_outs, + p_ds, + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op, + stream); +} + } // namespace ref } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp index 131b632a25..7bf9b49998 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp @@ -10,48 +10,56 @@ #include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include namespace ck { namespace ref { -// Optimized convolution kernel working with packed (contiguous) tensors +// Optimized convolution kernel working with packed (contiguous) tensors with multi-ABD support // Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], // output[G][N][K][spatial] template -__global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, - const WeiDataType* __restrict__ p_wei, - OutDataType* __restrict__ p_out, - index_t G, - index_t N, - index_t K, - index_t C, - index_t Di, - index_t Hi, - index_t Wi, - index_t Z, - index_t Y, - index_t X, - index_t Do, - index_t Ho, - index_t Wo, - index_t stride_z, - index_t stride_y, - index_t stride_x, - index_t dilation_z, - index_t dilation_y, - index_t dilation_x, - index_t pad_z, - index_t pad_y, - index_t pad_x, - InElementOp in_op, - WeiElementOp wei_op, - OutElementOp out_op) +__global__ void naive_conv_fwd_packed_multi_abd( + const InDataType* const* __restrict__ p_ins, // Array of input pointers (1 + NumAExtra) + const WeiDataType* const* __restrict__ p_weis, // Array of weight pointers (1 + NumBExtra) + const DDataType* const* __restrict__ p_ds, // Array of D tensor pointers + const index_t* const* __restrict__ p_d_strides, // Array of D tensor stride arrays + OutDataType* __restrict__ p_out, + index_t G, + index_t N, + index_t K, + index_t C, + index_t Di, + index_t Hi, + index_t Wi, + index_t Z, + index_t Y, + index_t X, + index_t Do, + index_t Ho, + index_t Wo, + index_t stride_z, + index_t stride_y, + index_t stride_x, + index_t dilation_z, + index_t dilation_y, + index_t dilation_x, + index_t pad_z, + index_t pad_y, + index_t pad_x, + InElementOp in_op, + WeiElementOp wei_op, + OutElementOp out_op) { const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x; const long_index_t num_threads = blockDim.x * gridDim.x; @@ -83,29 +91,48 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const InDataType* in_g = p_in + g * in_stride_g + n * in_stride_n; - const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k; + float acc = 0.0f; + // Base pointers for current group, batch, and output channel + const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n; + const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k; for(index_t c = 0; c < C; ++c) { - const InDataType* in_gc = in_g + c * in_stride_c; - const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c; + // Pointers at current input channel + const InDataType* input_at_c = input_g_n + c * in_stride_c; + const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c; for(index_t x = 0; x < X; ++x) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gc[wi]); - wei_op(wei_val, wei_gkc[x]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_c, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c, + wi); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_c, + p_weis + 1, + g * wei_stride_g + k * wei_stride_k + c * wei_stride_c, + x); + acc += type_convert(in_val) * type_convert(wei_val); } } } - OutDataType result = type_convert(acc); - out_op(out_val, result); + detail::apply_d_tensor_elementwise_op( + out_val, out_op, acc, p_ds, p_d_strides, g, n, k, wo); + p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + wo] = out_val; } } @@ -137,30 +164,51 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n; - const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k; + float acc = 0.0f; + // Base pointers for current group, batch, and output channel + const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n; + const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k; for(index_t c = 0; c < C; ++c) { - const InDataType* in_gnc = in_gn + c * in_stride_c; - const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c; + // Pointers at current input channel + const InDataType* input_at_c = input_g_n + c * in_stride_c; + const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c; for(index_t y = 0; y < Y; ++y) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gnch = in_gnc + hi * in_stride_h; - const WeiDataType* wei_gkcy = wei_gkc + y * wei_stride_y; + // Pointers at current spatial height and filter Y position + const InDataType* input_at_h = input_at_c + hi * in_stride_h; + const WeiDataType* weight_at_y = weight_at_c + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gnch[wi]); - wei_op(wei_val, wei_gkcy[x]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + hi * in_stride_h, + wi); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_y, + p_weis + 1, + g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + + y * wei_stride_y, + x); + acc += type_convert(in_val) * type_convert(wei_val); } } @@ -168,8 +216,17 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, } } - OutDataType result = type_convert(acc); - out_op(out_val, result); + detail::apply_d_tensor_elementwise_op(out_val, + out_op, + acc, + p_ds, + p_d_strides, + g, + n, + k, + ho * p_d_strides[0][3] + + wo * p_d_strides[0][4]); + p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + ho * out_stride_h + wo] = out_val; } @@ -207,38 +264,60 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n; - const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k; + float acc = 0.0f; + // Base pointers for current group, batch, and output channel + const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n; + const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k; for(index_t c = 0; c < C; ++c) { - const InDataType* in_gnc = in_gn + c * in_stride_c; - const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c; + // Pointers at current input channel + const InDataType* input_at_c = input_g_n + c * in_stride_c; + const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c; for(index_t z = 0; z < Z; ++z) { long_index_t di = do_idx * stride_z + z * dilation_z - pad_z; if(di >= 0 && di < Di) { - const InDataType* in_gncd = in_gnc + di * in_stride_d; - const WeiDataType* wei_gkcz = wei_gkc + z * wei_stride_z; + // Pointers at current spatial depth + const InDataType* input_at_d = input_at_c + di * in_stride_d; + const WeiDataType* weight_at_z = weight_at_c + z * wei_stride_z; for(index_t y = 0; y < Y; ++y) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gncdh = in_gncd + hi * in_stride_h; - const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y; + // Pointers at current spatial depth and height + const InDataType* input_at_d_h = input_at_d + hi * in_stride_h; + const WeiDataType* weight_at_z_y = weight_at_z + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gncdh[wi]); - wei_op(wei_val, wei_gkczy[x]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_d_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + di * in_stride_d + hi * in_stride_h, + wi); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_z_y, + p_weis + 1, + g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + + z * wei_stride_z + y * wei_stride_y, + x); + acc += type_convert(in_val) * type_convert(wei_val); } @@ -249,16 +328,28 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, } } - OutDataType result = type_convert(acc); - out_op(out_val, result); + detail::apply_d_tensor_elementwise_op( + out_val, + out_op, + acc, + p_ds, + p_d_strides, + g, + n, + k, + do_idx * p_d_strides[0][3] + ho * p_d_strides[0][4] + wo * p_d_strides[0][5]); + p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + do_idx * out_stride_d + ho * out_stride_h + wo] = out_val; } } } -// GPU reference convolution - takes ConvParam directly -template -void naive_conv_fwd(const TIn* p_in, - const TWei* p_wei, - TOut* p_out, - const ck::utils::conv::ConvParam& conv_param, - InElementwiseOperation in_element_op = InElementwiseOperation{}, - WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, - OutElementwiseOperation out_element_op = OutElementwiseOperation{}, - hipStream_t stream = nullptr) + typename OutElementwiseOperation, + typename TD = TOut> // D tensor type, defaults to TOut for backward compatibility +void naive_conv_fwd_multi_abd( + const std::array& p_ins, + const std::array& p_weis, + const std::array& p_ds, + TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, + const std::array, NumDElementwise>& d_strides, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) { const auto ndim = conv_param.num_dim_spatial_; @@ -303,13 +399,37 @@ void naive_conv_fwd(const TIn* p_in, for(auto l : out_lengths) out_total *= l; - // Allocate packed buffers - SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn)); - SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei)); + // Allocate packed buffers for all A and B tensors + // Use separate allocations to avoid copy assignment issues with RAII wrapper + std::vector in_packed_bufs; + in_packed_bufs.reserve(NumAElementwise + 1); + for(index_t i = 0; i <= NumAElementwise; ++i) + { + in_packed_bufs.emplace_back(in_total * sizeof(TIn)); + } + + std::vector wei_packed_bufs; + wei_packed_bufs.reserve(NumBElementwise + 1); + for(index_t i = 0; i <= NumBElementwise; ++i) + { + wei_packed_bufs.emplace_back(wei_total * sizeof(TWei)); + } + SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut)); - TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); - TWei* p_wei_packed = static_cast(wei_packed_buf.GetDeviceBuffer()); + // Get packed buffer pointers + std::array p_ins_packed; + for(index_t i = 0; i <= NumAElementwise; ++i) + { + p_ins_packed[i] = static_cast(in_packed_bufs[i].GetDeviceBuffer()); + } + + std::array p_weis_packed; + for(index_t i = 0; i <= NumBElementwise; ++i) + { + p_weis_packed[i] = static_cast(wei_packed_bufs[i].GetDeviceBuffer()); + } + TOut* p_out_packed = static_cast(out_packed_buf.GetDeviceBuffer()); // Compute strides and allocate device arrays for pack/unpack @@ -347,12 +467,82 @@ void naive_conv_fwd(const TIn* p_in, // Pack input and weight tensors to contiguous layout constexpr int block_size = 256; - strided_copy_kernel - <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total); - strided_copy_kernel - <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total); + + // Pack all A tensors + for(index_t i = 0; i <= NumAElementwise; ++i) + { + strided_copy_kernel + <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total); + } + + // Pack all B tensors + for(index_t i = 0; i <= NumBElementwise; ++i) + { + strided_copy_kernel + <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total); + } + + // Prepare D tensor stride arrays on device + // NOTE: D tensors are NOT packed - they are used directly with their original strides + // to support broadcasting (e.g., BiasGK layout with zero strides) + std::vector d_stride_bufs; + std::array p_d_strides_dev = {}; + + if constexpr(NumDElementwise > 0) + { + d_stride_bufs.reserve(NumDElementwise); + + for(index_t i = 0; i < NumDElementwise; ++i) + { + // Allocate and copy strides to device + d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t)); + p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i], + d_strides[i].data(), + d_strides[i].size() * sizeof(index_t), + hipMemcpyHostToDevice)); + } + } + + // Create device arrays of pointers + SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*)); + SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*)); + SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*)); + SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*)); + + TIn** d_ins_ptrs = static_cast(ins_ptrs_buf.GetDeviceBuffer()); + TWei** d_weis_ptrs = static_cast(weis_ptrs_buf.GetDeviceBuffer()); + TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); + index_t** d_d_strides_ptrs = static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs, + p_ins_packed.data(), + (NumAElementwise + 1) * sizeof(TIn*), + hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs, + p_weis_packed.data(), + (NumBElementwise + 1) * sizeof(TWei*), + hipMemcpyHostToDevice)); + + if constexpr(NumDElementwise > 0) + { + // D tensors use original pointers (not packed) to support broadcasting + std::array p_ds_dev; + for(index_t i = 0; i < NumDElementwise; ++i) + { + p_ds_dev[i] = p_ds[i]; + } + + HIP_CHECK_ERROR(hipMemcpy( + d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs, + p_d_strides_dev.data(), + NumDElementwise * sizeof(index_t*), + hipMemcpyHostToDevice)); + } // Build conv parameter vectors for kernel invocation std::vector conv_strides(ndim); @@ -370,15 +560,21 @@ void naive_conv_fwd(const TIn* p_in, if(ndim == 1) { - naive_conv_fwd_packed<1, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, - p_wei_packed, + naive_conv_fwd_packed_multi_abd<1, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, + d_weis_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, p_out_packed, G, N, @@ -408,15 +604,21 @@ void naive_conv_fwd(const TIn* p_in, } else if(ndim == 2) { - naive_conv_fwd_packed<2, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, - p_wei_packed, + naive_conv_fwd_packed_multi_abd<2, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, + d_weis_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, p_out_packed, G, N, @@ -446,15 +648,21 @@ void naive_conv_fwd(const TIn* p_in, } else // 3D { - naive_conv_fwd_packed<3, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, - p_wei_packed, + naive_conv_fwd_packed_multi_abd<3, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, + d_weis_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, p_out_packed, G, N, @@ -492,5 +700,43 @@ void naive_conv_fwd(const TIn* p_in, // Memory automatically freed by SimpleDeviceMem destructors } +// Original naive_conv_fwd - now a zero-overhead wrapper +template +inline void naive_conv_fwd(const TIn* p_in, + const TWei* p_wei, + TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) +{ + std::array p_ins = {p_in}; + std::array p_weis = {p_wei}; + std::array p_ds = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + naive_conv_fwd_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_ins, + p_weis, + p_ds, + p_out, + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op, + stream); +} + } // namespace ref } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp index 0a7b58b310..50b65357a2 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp @@ -22,9 +22,39 @@ struct SimpleDeviceMem HIP_CHECK_ERROR(hipMalloc(static_cast(&p_mem_), mem_size)); } + // Delete copy operations (resource should not be copied) + SimpleDeviceMem(const SimpleDeviceMem&) = delete; + SimpleDeviceMem& operator=(const SimpleDeviceMem&) = delete; + + // Define move operations + SimpleDeviceMem(SimpleDeviceMem&& other) noexcept : p_mem_(other.p_mem_) + { + other.p_mem_ = nullptr; + } + + SimpleDeviceMem& operator=(SimpleDeviceMem&& other) noexcept + { + if(this != &other) + { + if(p_mem_) + { + (void)hipFree(p_mem_); + } + p_mem_ = other.p_mem_; + other.p_mem_ = nullptr; + } + return *this; + } + void* GetDeviceBuffer() { return p_mem_; } - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + ~SimpleDeviceMem() + { + if(p_mem_) + { + (void)hipFree(p_mem_); + } + } void* p_mem_; }; @@ -173,5 +203,90 @@ __global__ void strided_copy_kernel(const DataType* __restrict__ src, } } +namespace detail { + +// Helper for parameter pack expansion (D tensors) +template +__device__ __forceinline__ void apply_multi_tensor_impl(ResultType& result, + Op&& element_op, + const DataType* const* tensor_ptrs, + long_index_t element_offset, + std::index_sequence) +{ + element_op(result, tensor_ptrs[Is][element_offset]...); +} + +// Generic helper for A and B tensors (works in all directions) +template +__device__ __forceinline__ void apply_multi_tensor_elementwise_op(ResultType& result, + Op&& element_op, + const DataType* primary_ptr, + const DataType* const* extra_ptrs, + long_index_t extra_base_offset, + long_index_t element_offset) +{ + const DataType* tensor_ptrs[NumExtraTensors + 1]; + tensor_ptrs[0] = primary_ptr; + + static_for<1, NumExtraTensors + 1, 1>{}( + [&](auto i) { tensor_ptrs[i] = extra_ptrs[i - 1] + extra_base_offset; }); + + apply_multi_tensor_impl(result, + element_op, + tensor_ptrs, + element_offset, + std::make_index_sequence{}); +} + +// Helper for parameter pack expansion (D tensors) +template +__device__ __forceinline__ void apply_d_tensor_impl(OutDataType& result_out, + Op&& element_op, + float computed_value, + const float* d_values, + std::index_sequence) +{ + float temp_out; + element_op(temp_out, computed_value, d_values[Is]...); + result_out = type_convert(temp_out); +} + +// Specialized helper for D tensors with stride calculations and float conversion +template +__device__ __forceinline__ void apply_d_tensor_elementwise_op(OutDataType& result_out, + Op&& element_op, + float computed_value, + const DDataType* const* p_ds, + const index_t* const* p_d_strides, + index_t g, + index_t n, + index_t c_or_k, + long_index_t spatial_linear_index) +{ + if constexpr(NumDTensors == 0) + { + element_op(result_out, computed_value); + } + else + { + float d_values[NumDTensors]; + + // Compute all D tensor indices and convert to float + static_for<0, NumDTensors, 1>{}([&](auto i) { + const long_index_t d_idx = g * p_d_strides[i][0] + n * p_d_strides[i][1] + + c_or_k * p_d_strides[i][2] + spatial_linear_index; + d_values[i] = type_convert(p_ds[i][d_idx]); + }); + + apply_d_tensor_impl(result_out, + element_op, + computed_value, + d_values, + std::make_index_sequence{}); + } +} + +} // namespace detail + } // namespace ref } // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp index 745f8cbd32..970bcb0439 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp @@ -376,7 +376,7 @@ using device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances = // clang-format on >; -#if defined(__gfx950__) +#if defined(CK_USE_GFX950) constexpr auto _k_per_block = 32; #else constexpr auto _k_per_block = 16; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp index 18abcb1613..3b7ce0df3a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp @@ -147,7 +147,7 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple< // clang-format on >; -#if defined(__gfx950__) +#if defined(CK_USE_GFX950) constexpr auto _k_per_block = 32; #else constexpr auto _k_per_block = 16; diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp index d79fe9bfa3..d7b654a345 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp @@ -47,7 +47,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::t DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp index e284cbbb83..7d7966c47f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp @@ -40,7 +40,7 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::t //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, @@ -49,7 +49,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::t DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp index 6195d40f87..2f63199480 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -41,7 +41,7 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::t //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, @@ -52,7 +52,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::t DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp index e51bec3dfb..b50e37cf0a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp @@ -44,7 +44,7 @@ using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp index 66ba1e3830..4651068d86 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp @@ -40,9 +40,9 @@ using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = std::tupl //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp index 8eccccf354..4dcbaccaa4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -41,7 +41,7 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, @@ -49,7 +49,7 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, diff --git a/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp b/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp index e58c884729..9accf6e336 100644 --- a/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp +++ b/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp @@ -7,12 +7,17 @@ #include "../../experimental/builder/test/utils/conv_algorithm_type_utils.hpp" #include "grouped_convolution_signatures.hpp" +#include "ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp" #include "ck_tile/builder/testing/filter_extent.hpp" -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" -#include "ck_tile/builder/testing/conv_fwd_reference.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" #include "ck_tile/builder/conv_builder.hpp" +// Temporary disable builder validate since we don't have deduced rtol, atol support +#define ENABLE_BUILDER_VALIDATE 0 + namespace ck_tile::builder::profiling { namespace ckb = ck_tile::builder; @@ -113,25 +118,66 @@ run_grouped_conv_forward_tile_algs(const ckt::Args& args, auto reference = ckt::alloc_outputs(args); using ReferenceInstance = typename ckb::ConvBuilder::Instance; - auto ref_conv = ReferenceInstance{}; - ckt::run(ref_conv, args, inputs, reference.get()); + auto ref_conv = ReferenceInstance{}; + [[maybe_unused]] auto ref_result = ckt::run(ref_conv, args, inputs, reference.get()); + +#if ENABLE_BUILDER_VALIDATE == 0 + using DataType = + std::conditional_t>; + const auto conv_param = args.to_ck_tile_conv_param(); + + const std::size_t output_bytes_num = conv_param.template GetOutputByte(); + std::vector out(output_bytes_num / sizeof(DataType)); + std::vector ref(output_bytes_num / sizeof(DataType)); + HIP_CHECK_ERROR( + hipMemcpy(&ref.data()[0], reference.get().output, output_bytes_num, hipMemcpyDeviceToHost)); + + const ck_tile::index_t GemmK = std::accumulate(conv_param.filter_spatial_lengths_.cbegin(), + conv_param.filter_spatial_lengths_.cend(), + 1, + std::multiplies()) * + conv_param.C_; + float max_accumulated_value = *std::max_element(ref.begin(), ref.end()); + const auto rtol = ck_tile::get_relative_threshold(GemmK); + const auto atol = + ck_tile::get_absolute_threshold(max_accumulated_value, GemmK); +#endif [[maybe_unused]] auto run_alg = [&](auto&& run_alg_func) { std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf); if(is_supported) { + best_avg_time = std::min(best_avg_time, avg_time); + best_op_name = best_avg_time < avg_time ? best_op_name : op_name; + std::cout << "Perf: " << std::setw(10) << avg_time << " ms," << " " << op_name + << std::endl; + +#if ENABLE_BUILDER_VALIDATE const auto errors = ckt::validate(args, outputs, reference.get()).get_errors(); for(const auto& error : errors) { valid = false; std::cout << "Number of incorrect values: " << error.wrong_elements - << " Is all zero:" << error.is_all_zero() << std::endl; + << " Is all zero:" << error.is_all_zero() + << " max err: " << error.max_error << std::endl; } - best_avg_time = std::min(best_avg_time, avg_time); - best_op_name = best_avg_time < avg_time ? best_op_name : op_name; - std::cout << "Perf: " << std::setw(10) << avg_time << " ms,"; +#else + HIP_CHECK_ERROR( + hipMemcpy(&out.data()[0], outputs.output, output_bytes_num, hipMemcpyDeviceToHost)); + valid = ck_tile::check_err(out, ref, "Error: Incorrect results!", rtol, atol); +#endif + + std::cout << "Relative error threshold: " << rtol + << " Absolute error threshold: " << atol << std::endl; + } + else + { + std::cout << " " << op_name << std::endl; } - std::cout << " " << op_name << std::endl; }; if constexpr(SIGNATURE == SIGNATURE_NHWGC_FP16_FWD) diff --git a/profiler/include/profiler/grouped_convolution_signatures.hpp b/profiler/include/profiler/grouped_convolution_signatures.hpp index 5103b0f235..0f87e283bb 100644 --- a/profiler/include/profiler/grouped_convolution_signatures.hpp +++ b/profiler/include/profiler/grouped_convolution_signatures.hpp @@ -6,7 +6,7 @@ #include #include "../../experimental/builder/test/impl/conv_signature_types.hpp" -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" namespace ck_tile::builder::profiling { diff --git a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp index a0f9b9ac25..bf5ffcb5d2 100644 --- a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp @@ -17,6 +17,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" namespace ck { namespace profiler { @@ -129,7 +130,10 @@ bool profile_conv_bwd_data_impl(int do_verification, out_device_buf.ToDevice(output.mData.data()); wei_device_buf.ToDevice(weight.mData.data()); - if(do_verification) + // profile device Conv instances + bool pass = true; + + if(do_verification == 1) { auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData gpu_ref_input(in_g_n_c_wis_desc); + if(do_verification == 2) + { + DeviceMem gpu_ref_in_dev(sizeof(InDataType) * + input_device_result.mDesc.GetElementSpaceSize()); + gpu_ref_in_dev.SetZero(); // bwd data needs zero initialization + + ck::ref::naive_conv_bwd_data( + static_cast(gpu_ref_in_dev.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + out_element_op); + + hip_check_error(hipDeviceSynchronize()); + gpu_ref_in_dev.FromDevice(gpu_ref_input.mData.data()); + } + using DeviceOp = ck::tensor_operation::device::DeviceConvBwdData gpu_ref_output(out_g_n_k_wos_desc); + if(do_verification == 2) + { + DeviceMem gpu_ref_out_dev(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + + ck::ref::naive_conv_fwd( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(gpu_ref_out_dev.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + out_element_op); + + hip_check_error(hipDeviceSynchronize()); + gpu_ref_out_dev.FromDevice(gpu_ref_output.mData.data()); + } using DeviceOp = ck::tensor_operation::device::DeviceConvFwd(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "gpu_ref_output : ", gpu_ref_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } } else { diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 3a9f14e595..afc88150ed 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -364,26 +364,39 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, using AccDataType = std::conditional_t, int32_t, float>; - // Calculate number of accumulations accounting for split_k - const int num_accums = - static_cast(output.GetElementSize() / conv_param.K_ / split_k_value); - - // Additional tolerance for split_k accumulation if needed - int total_accums = num_accums; - if(split_k_value > 1) - { - total_accums = std::max(num_accums, static_cast(split_k_value)); - } - - // Perform GPU verification (max value computed internally on GPU) + const index_t num_accums = output.GetElementSize() / conv_param.K_; + const index_t num_accums_split_k = split_k_value; + // Get maximum accumulated value from reference const std::size_t tensor_size = weight_device_result.mDesc.GetElementSpaceSize(); + max_accumulated_value = + gpu_reduce_max(gpu_ref_wei_buf.GetDeviceBuffer(), tensor_size); + // Calculate thresholds + auto rtol = + ck::utils::get_relative_threshold( + num_accums / num_accums_split_k); + auto atol = + ck::utils::get_absolute_threshold( + max_accumulated_value / num_accums_split_k, + num_accums / num_accums_split_k); + // Calculate error due to split_k accumulation + auto rtol_split_k = + ck::utils::get_relative_threshold( + num_accums_split_k); + auto atol_split_k = + ck::utils::get_absolute_threshold( + max_accumulated_value, num_accums_split_k); + // Use higher threshold + rtol = std::max(rtol, rtol_split_k); + atol = std::max(atol, atol_split_k); + + // Perform GPU verification auto gpu_result = - ck::profiler::gpu_verify( - wei_device_buf.GetDeviceBuffer(), - gpu_ref_wei_buf.GetDeviceBuffer(), - total_accums, - tensor_size); + ck::profiler::gpu_verify(wei_device_buf.GetDeviceBuffer(), + gpu_ref_wei_buf.GetDeviceBuffer(), + rtol, + atol, + tensor_size); if(!gpu_result) { diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp index 50cd58eec3..2a282edbc8 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp @@ -21,6 +21,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" namespace ck { namespace profiler { @@ -156,8 +157,9 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, bias_device_buf.ToDevice(bias.mData.data()); // run reference op - if(do_verification) + if(do_verification == 1) { + // CPU reference auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd d_lengths_vec(NDimSpatial + 3); + std::vector d_strides_vec(NDimSpatial + 3); + + d_lengths_vec[0] = conv_param.G_; + d_lengths_vec[1] = conv_param.N_; + d_lengths_vec[2] = conv_param.K_; + for(ck::index_t i = 0; i < NDimSpatial; ++i) + { + d_lengths_vec[3 + i] = static_cast(conv_param.output_spatial_lengths_[i]); + } + + if constexpr(BiasGK) + { + // For GK bias layout: G*K, zero strides for N and spatial dimensions + d_strides_vec[0] = K; + d_strides_vec[1] = 0; + d_strides_vec[2] = 1; + for(ck::index_t i = 0; i < NDimSpatial; ++i) + { + d_strides_vec[3 + i] = 0; + } + } + else + { + // Full GNKHW layout - same as output + ck::ranges::copy(out_g_n_k_wos_desc.GetStrides(), d_strides_vec.begin()); + } + + std::array d_ptrs = { + reinterpret_cast(bias_device_buf.GetDeviceBuffer())}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; + + std::array in_ptrs = { + reinterpret_cast(in_device_buf.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(wei_device_buf.GetDeviceBuffer())}; + + ck::ref::naive_conv_fwd_multi_abd<0, + 0, + 1, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + OutDataType>( // Explicitly specify TD = OutDataType + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + out_device_buf.FromDevice(host_output.mData.data()); + } std::string best_op_name; float best_avg_time = 0; diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp index 3f4905c110..b439428cda 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp @@ -22,6 +22,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" namespace ck { namespace profiler { @@ -129,8 +130,9 @@ bool profile_grouped_conv_fwd_bilinear_impl( wei_device_buf.ToDevice(weight.mData.data()); d_device_buf.ToDevice(d_tensor.mData.data()); - if(do_verification) + if(do_verification == 1) { + // CPU reference auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< NDimSpatial, InDataType, @@ -167,6 +169,61 @@ bool profile_grouped_conv_fwd_bilinear_impl( host_output(idx) = ck::type_convert(out_val); }); } + else if(do_verification == 2) + { + // GPU reference + std::vector d_lengths_vec(NDimSpatial + 3); + std::vector d_strides_vec(NDimSpatial + 3); + + d_lengths_vec[0] = conv_param.G_; + d_lengths_vec[1] = conv_param.N_; + d_lengths_vec[2] = conv_param.K_; + for(ck::index_t i = 0; i < NDimSpatial; ++i) + { + d_lengths_vec[3 + i] = static_cast(conv_param.output_spatial_lengths_[i]); + } + + // D tensor has same layout as output + ck::ranges::copy(d_host_tensor_descriptor.GetStrides(), d_strides_vec.begin()); + + std::array d_ptrs = { + reinterpret_cast(d_device_buf.GetDeviceBuffer())}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; + + std::array in_ptrs = { + reinterpret_cast(in_device_buf.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(wei_device_buf.GetDeviceBuffer())}; + + ck::ref::naive_conv_fwd_multi_abd<0, + 0, + 1, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DDataType>( // Explicitly specify D tensor type + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + d_lengths, + d_strides, + InElementOp{}, + WeiElementOp{}, + bilinear_op); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + out_device_buf.FromDevice(host_output.mData.data()); + } std::string best_op_name; float best_avg_time = 0; diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp index acdc937a33..9444996c25 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp @@ -7,6 +7,7 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "profiler/common.hpp" @@ -150,7 +151,7 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, std::cout << "scale_out: " << scale_out << std::endl; // run reference op - if(do_verification) + if(do_verification == 1) { std::cout << "\nVerifying algorithm against reference convolution..." << std::endl; @@ -200,6 +201,57 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, } }); } + else if(do_verification == 2) + { + // GPU reference + // WORKAROUND: For int8_t with Scale, use CPU post-processing to match CPU reference + // Pure GPU approach fails int8 test (see 2026-01-07-int8-scale-debugging.md) + if constexpr(std::is_same_v && + std::is_same_v) + { + // Compute conv to CShuffleDataType (float), then post-process on CPU + DeviceMem gpu_ref_c_dev(sizeof(CShuffleDataType) * c.mDesc.GetElementSpaceSize()); + + ck::ref::naive_conv_fwd( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(gpu_ref_c_dev.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + PassThrough{}); + + ck::hip_check_error(hipDeviceSynchronize()); + + Tensor gpu_c(out_g_n_k_wos_desc); + gpu_ref_c_dev.FromDevice(gpu_c.mData.data()); + + // Post-process on CPU to match CPU reference behavior + host_output.ForEach([&](auto&, auto idx) { + const auto conv_shuffle = ck::type_convert(gpu_c(idx)); + const auto conv_val = ck::type_convert(conv_shuffle); + out_element_op(host_output(idx), conv_val); + }); + } + else + { + // Normal path for non-int8 or non-Scale cases + DeviceMem gpu_ref_out_dev(sizeof(OutDataType) * + device_output.mDesc.GetElementSpaceSize()); + + ck::ref::naive_conv_fwd( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(gpu_ref_out_dev.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + out_element_op); + + ck::hip_check_error(hipDeviceSynchronize()); + gpu_ref_out_dev.FromDevice(host_output.mData.data()); + } + } std::string best_op_name; float best_avg_time = 0; @@ -239,7 +291,7 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, best_gb_per_sec = gb_per_sec; } - if(do_verification) + if(do_verification == 1) { out_device_buf.FromDevice(device_output.mData.data()); @@ -259,6 +311,27 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, << std::endl; } } + else if(do_verification == 2) + { + out_device_buf.FromDevice(device_output.mData.data()); + + pass = + pass & ck::utils::check_err(device_output, + host_output, + "Error: Device and GPU ref results do not match!", + get_rtol(), + get_atol()); + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType(std::cout << "gpu_ref_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } } else { diff --git a/profiler/src/profile_grouped_conv_fwd_tile.cpp b/profiler/src/profile_grouped_conv_fwd_tile.cpp index 8023dcf2f6..1a1e8b769a 100644 --- a/profiler/src/profile_grouped_conv_fwd_tile.cpp +++ b/profiler/src/profile_grouped_conv_fwd_tile.cpp @@ -6,7 +6,7 @@ #include #include -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" #include "ck_tile/host/device_prop.hpp" #include "profiler/grouped_convolution_forward_tile_algs.hpp" diff --git a/script/analyze_build/README.md b/script/analyze_build/README.md new file mode 100644 index 0000000000..7a88b98e77 --- /dev/null +++ b/script/analyze_build/README.md @@ -0,0 +1,263 @@ +# Build Trace Analysis + +Simple to use, fast python tools for analyzing Clang `-ftime-trace` build performance data. + +## Overview + +We're kicking off a systematic effort to dramatically reduce CK and CK-Tile build times, [#3575](https://github.com/ROCm/composable_kernel/issues/3575). A key part of this work is improving our C++ metaprogramming to reduce the burden on the compiler. + +In order to prioritize work and measure our progress, we need data on template instantiation. For single files, Clang's `-ftime-trace` build performance data is easy to analyze with the Perfetto UI. The problem we are solving here is how to analyze instantiation data across thousands of compilation units. + +The python code in this directory provides helper functions to quickly load JSON files into pandas DataFrames that can be used for analysis in Jupyter notebooks. + +## Directory Structure + +``` +script/analyze_build/ +├── trace_analysis/ # Core library +│ ├── __init__.py # Main exports +│ ├── parse_file.py # Fast parsing of JSON trace files +│ ├── template_analysis.py # Template instantiation analysis +│ ├── template_parser.py # Template name parsing utilities +│ └── phase_breakdown.py # Compilation phase breakdown +├── notebooks/ # Jupyter notebooks for analysis +│ └── file_analysis_example.ipynb # Template analysis example +├── requirements.txt # Python dependencies +└── README.md # This file +``` + +## Python Requirements + +See `requirements.txt` for the complete list of dependencies: +* **pandas** - DataFrame manipulation and analysis +* **orjson** - Fast JSON parsing for trace files +* **plotly** - Interactive visualizations (sunburst, treemap) +* **nbformat** - Jupyter notebook format support +* **ipykernel** - Kernel for running notebooks in VSCode/Jupyter +* **kaleido** - Static image export from Plotly charts +* **jupyter** - Full Jupyter environment + +## Quick Start + +### Setup + +1. Create a virtual environment (recommended): +```bash +cd script/analyze_build +python3 -m venv .venv +source .venv/bin/activate # On Windows: .venv\Scripts\activate +``` + +2. Install dependencies: +```bash +pip install -r requirements.txt +``` + +3. Install VSCode extensions if you want to run notebooks in VSCode: + * Jupyter + * Data Wrangler (interact with Pandas DataFrames) + +### Analyzing a Single File + +Use the `parse_file` function to load a `-ftime-trace` JSON file into a Pandas DataFrame: + +```python +from trace_analysis import parse_file + +# Parse the trace file +df = parse_file('path/to/trace.json') + +# View basic info +print(f"Total events: {len(df)}") +print(df.columns) + +# Analyze duration statistics +print(df['dur'].describe()) +``` + +### Extracting Compilation Metadata + +Get high-level metadata about the compilation: + +```python +from trace_analysis import get_metadata + +# Extract metadata from trace file +metadata = get_metadata('trace.json') + +print(f"Source file: {metadata['source_file']}") +print(f"Compilation time: {metadata['total_wall_time_s']:.2f}s") +print(f"Started: {metadata['wall_start_datetime']}") +print(f"Ended: {metadata['wall_end_datetime']}") +``` + +The metadata includes: +- `source_file`: Main .cpp/.c file being compiled +- `time_granularity`: Time unit used ("microseconds") +- `beginning_of_time`: Epoch timestamp in microseconds +- `wall_start_time`: Wall clock start (microseconds since epoch) +- `wall_end_time`: Wall clock end (microseconds since epoch) +- `wall_start_datetime`: Human-readable start time +- `wall_end_datetime`: Human-readable end time +- `total_wall_time_us`: Total compilation time in microseconds +- `total_wall_time_s`: Total compilation time in seconds + +### Template Instantiation Analysis + +The module includes specialized functions for analyzing C++ template instantiation costs: + +```python +from trace_analysis import ( + parse_file, + get_template_instantiation_events, + get_phase_breakdown, +) + +df = parse_file('trace.json') + +# Get all template instantiation events with parsed template information +template_events = get_template_instantiation_events(df) + +# The returned DataFrame includes parsed columns: +# - namespace: Top-level namespace (e.g., 'std', 'ck') +# - template_name: Template name without parameters +# - full_qualified_name: Full namespace::template_name +# - param_count: Number of template parameters +# - is_ck_type: Boolean indicating CK library types +# - is_nested: Boolean indicating nested templates + +# Find slowest template instantiations +top_templates = template_events.nlargest(20, 'dur') +print(top_templates[['template_name', 'namespace', 'param_count', 'dur']]) + +# Analyze by namespace +namespace_summary = template_events.groupby('namespace').agg({ + 'dur': ['count', 'sum', 'mean'] +}) +print(namespace_summary) +``` + +### Compilation Phase Breakdown + +Analyze how compilation time is distributed across different phases: + +```python +from trace_analysis import get_phase_breakdown, PhaseBreakdown + +df = parse_file('trace.json') + +# Get hierarchical phase breakdown +breakdown = get_phase_breakdown(df) + +# Display in Jupyter (automatic rich HTML display) +display(breakdown) + +# Print text representation +print(breakdown) + +# Access the underlying DataFrame +print(breakdown.df) + +# Convert to plotly format for visualization +import plotly.express as px +data = breakdown.to_plotly() +fig = px.sunburst(**data) +fig.show() +``` + +The `PhaseBreakdown` class provides: +- Hierarchical breakdown of compilation phases +- Automatic calculation of "Other" residual time at each level +- Validation that children don't exceed parent durations +- Multiple output formats (text, DataFrame, Plotly) + +## DataFrame Schema + +The parsed DataFrame contains the following columns from the `-ftime-trace` format: + +- `name`: Event name (function, template instantiation, etc.) +- `ph`: Phase character ('X' for complete, 'B' for begin, 'E' for end, 'i' for instant) +- `ts`: Timestamp in microseconds +- `dur`: Duration in microseconds (for complete events) +- `pid`: Process ID +- `tid`: Thread ID +- `arg_*`: Flattened arguments from the event's `args` field + +### Template Event Columns + +When using `get_template_instantiation_events()`, additional parsed columns are included: + +- `namespace`: Top-level namespace extracted from the template name +- `template_name`: Template name without namespace or parameters +- `full_qualified_name`: Complete namespace::template_name +- `param_count`: Number of template parameters +- `is_ck_type`: Boolean flag for CK library types (namespace starts with 'ck') +- `is_nested`: Boolean flag indicating nested template instantiations + +## Use in Jupyter Notebooks + +The module is designed to work seamlessly in Jupyter notebooks. See `notebooks/file_analysis_example.ipynb` for a complete example workflow that demonstrates: + +- Loading and parsing trace files +- Extracting compilation metadata +- Analyzing phase breakdown with visualizations +- Template instantiation analysis with parsed columns +- Filtering and grouping by namespace +- Identifying CK-specific template costs + +To use in a notebook: + +```python +import sys +from pathlib import Path + +# Add trace_analysis to path +sys.path.insert(0, str(Path.cwd().parent)) + +from trace_analysis import ( + parse_file, + get_metadata, + get_template_instantiation_events, + get_phase_breakdown, +) + +# Load and analyze +df = parse_file('path/to/trace.json') +breakdown = get_phase_breakdown(df) +templates = get_template_instantiation_events(df) + +# Visualize +import plotly.express as px +fig = px.sunburst(**breakdown.to_plotly()) +fig.show() +``` + +## API Reference + +### Core Functions + +- `parse_file(filepath)`: Parse a `-ftime-trace` JSON file into a pandas DataFrame +- `get_metadata(filepath_or_df)`: Extract compilation metadata from trace file or DataFrame + +### Template Analysis + +- `get_template_instantiation_events(df)`: Filter to template instantiation events with parsed template information + +### Phase Breakdown + +- `get_phase_breakdown(df)`: Generate hierarchical compilation phase breakdown +- `PhaseBreakdown`: Class representing phase breakdown with multiple output formats + +## Contributing + +This is an experimental project for analyzing and improving C++ metaprogramming build times. Contributions are welcome! When adding new analysis functions: + +1. Add the function to the appropriate module in `trace_analysis/` +2. Export it in `__init__.py` +3. Update this README with usage examples +4. Consider adding a notebook example if the feature is substantial + +## License + +Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +SPDX-License-Identifier: MIT diff --git a/script/analyze_build/notebooks/file_analysis_example.ipynb b/script/analyze_build/notebooks/file_analysis_example.ipynb new file mode 100644 index 0000000000..e8d1ee3bcd --- /dev/null +++ b/script/analyze_build/notebooks/file_analysis_example.ipynb @@ -0,0 +1,247 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Template Instantiation Analysis Example\n", + "\n", + "This notebook demonstrates how to use the template analysis functions to understand C++ template instantiation costs in Clang's `-ftime-trace` output.\n", + "\n", + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "# Add parent directory to path\n", + "sys.path.insert(0, str(Path.cwd().parent))\n", + "\n", + "from trace_analysis import (\n", + " parse_file,\n", + " get_template_instantiation_events,\n", + " get_phase_breakdown,\n", + " get_metadata,\n", + ")\n", + "\n", + "import pandas as pd\n", + "from datetime import datetime\n", + "import plotly.express as px\n", + "\n", + "\n", + "# Display settings\n", + "pd.set_option(\"display.max_rows\", 100)\n", + "pd.set_option(\"display.max_columns\", None)\n", + "pd.set_option(\"display.width\", None)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Trace File" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load your trace file\n", + "trace_file = Path(\n", + " \"../../../build-trace/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeFiles/device_conv2d_fwd_instance.dir/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp.json\"\n", + ")\n", + "df = parse_file(trace_file)\n", + "\n", + "print(f\"Total events: {len(df):,}\")\n", + "starting_timestamp = datetime.fromtimestamp(df.attrs[\"beginningOfTime\"] / 1e6)\n", + "print(f\"Starting timestamp: {starting_timestamp.strftime('%Y-%m-%d:%H:%M:%S')}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_metadata(df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compilation Overview" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get phase breakdown and display it\n", + "breakdown = get_phase_breakdown(df)\n", + "print(breakdown)\n", + "display(breakdown)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Extract data for plotly charts (sunburst, tree-map, or icicle)\n", + "plotly_data = breakdown.to_plotly()\n", + "fig = px.sunburst(**plotly_data)\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Template Instantiation Analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get all template instantiation events (now with parsed columns!)\n", + "template_events = get_template_instantiation_events(df)\n", + "\n", + "print(f\"Total template instantiation events: {len(template_events):,}\")\n", + "print(f\"Total template time: {template_events['dur'].sum() / 1000:.1f} ms\")\n", + "display(template_events)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Examine Parsed Columns\n", + "\n", + "The `get_template_instantiation_events()` function automatically parses the `arg_detail` column into structured fields:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show the new parsed columns\n", + "print(\"Parsed columns available:\")\n", + "print(\"- namespace: Top-level namespace (e.g., 'std', 'ck')\")\n", + "print(\"- template_name: Template name without parameters\")\n", + "print(\"- full_qualified_name: Full namespace::template_name\")\n", + "print(\"- param_count: Number of template parameters\")\n", + "print(\"- is_ck_type: Boolean indicating CK library types\")\n", + "print(\"- is_nested: Boolean indicating nested templates\")\n", + "print()\n", + "\n", + "# Display sample of parsed data\n", + "template_events[\n", + " [\"namespace\", \"template_name\", \"param_count\", \"is_ck_type\", \"is_nested\", \"dur\"]\n", + "].head(20)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Analysis by Namespace" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Group by namespace to see where time is spent\n", + "namespace_summary = (\n", + " template_events.groupby(\"namespace\")\n", + " .agg({\"dur\": [\"count\", \"sum\", \"mean\"], \"param_count\": \"mean\"})\n", + " .round(2)\n", + ")\n", + "\n", + "namespace_summary.columns = [\"count\", \"total_dur\", \"avg_dur\", \"avg_params\"]\n", + "namespace_summary[\"total_ms\"] = namespace_summary[\"total_dur\"] / 1000\n", + "namespace_summary = namespace_summary.sort_values(\"total_dur\", ascending=False)\n", + "\n", + "print(\"\\nTemplate Instantiation Time by Namespace:\")\n", + "display(namespace_summary)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CK Library Templates Analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Filter to CK types only\n", + "ck_templates = template_events[template_events[\"is_ck_type\"]].copy()\n", + "\n", + "print(f\"CK template instantiations: {len(ck_templates):,}\")\n", + "print(f\"CK template time: {ck_templates['dur'].sum() / 1000:.1f} ms\")\n", + "print(\n", + " f\"Percentage of total template time: {100 * ck_templates['dur'].sum() / template_events['dur'].sum():.1f}%\"\n", + ")\n", + "print()\n", + "\n", + "# Top CK templates by time\n", + "ck_by_name = (\n", + " ck_templates.groupby(\"template_name\")\n", + " .agg({\"dur\": [\"count\", \"sum\", \"mean\"]})\n", + " .round(2)\n", + ")\n", + "ck_by_name.columns = [\"count\", \"total_dur\", \"avg_dur\"]\n", + "ck_by_name[\"total_ms\"] = ck_by_name[\"total_dur\"] / 1000\n", + "ck_by_name = ck_by_name.sort_values(\"total_dur\", ascending=False)\n", + "\n", + "print(\"\\nTop CK Templates by Total Time:\")\n", + "display(ck_by_name.head(20))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/script/analyze_build/requirements.txt b/script/analyze_build/requirements.txt new file mode 100644 index 0000000000..fd99fdba09 --- /dev/null +++ b/script/analyze_build/requirements.txt @@ -0,0 +1,18 @@ +# Build Trace Analysis - Python Dependencies + +# Core data processing +pandas>=2.0.0 +orjson>=3.9.0 + +# Jupyter notebook support +nbformat>=4.2.0 +ipykernel>=6.0.0 + +# Interactive visualizations +plotly>=5.0.0 + +# Static image export from Plotly +kaleido>=0.2.0 + +# Full Jupyter environment (if not using VSCode) +jupyter>=1.0.0 diff --git a/script/analyze_build/trace_analysis/__init__.py b/script/analyze_build/trace_analysis/__init__.py new file mode 100644 index 0000000000..70db321083 --- /dev/null +++ b/script/analyze_build/trace_analysis/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Build Trace Analysis - Core library for analyzing Clang -ftime-trace data. + +This package provides tools to parse and analyze Clang's -ftime-trace JSON output +for build performance analysis. +""" + +from .parse_file import ( + parse_file, + get_metadata, +) + +from .template_analysis import ( + get_template_instantiation_events, +) + +from .phase_breakdown import ( + get_phase_breakdown, + PhaseBreakdown, +) + +__all__ = [ + # Core parsing and filtering + "parse_file", + "get_metadata", + # Template analysis + "get_template_instantiation_events", + # Phase breakdown + "get_phase_breakdown", + "PhaseBreakdown", +] diff --git a/script/analyze_build/trace_analysis/parse_file.py b/script/analyze_build/trace_analysis/parse_file.py new file mode 100644 index 0000000000..24d71e4eb8 --- /dev/null +++ b/script/analyze_build/trace_analysis/parse_file.py @@ -0,0 +1,356 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Parse a single Clang -ftime-trace JSON file into a Pandas DataFrame. + +This module provides fast parsing of Clang's -ftime-trace output using orjson +for performance. The JSON file is typically a single-line array of trace events. +""" + +import orjson +import pandas as pd +from pathlib import Path +from typing import Union, Optional +from datetime import datetime +from dataclasses import dataclass + + +# Expected schema for trace event DataFrames with optimized dtypes +# This enforces strict column validation and memory-efficient types +# The memory usage is dominated by arg detail, but we optimize each series. +TRACE_EVENT_DTYPES = { + "pid": "int32", # Process ID (max observed: ~2.3M, fits in int32) + "tid": "int32", # Thread ID (max observed: ~2.3M, fits in int32) + "ts": "int64", # Timestamp in microseconds (requires int64 for epoch times) + "cat": "category", # Category (low cardinality, use categorical) + "ph": "category", # Phase type (very low cardinality: X, B, E, i, etc.) + "id": "int64", # Event ID + "name": "category", # Event name (medium cardinality, use categorical) + "dur": "int64", # Duration in microseconds (max 10 days = 864B μs, requires int64) + "arg_detail": "string", # Detail string (high cardinality, keep as string) + "arg_count": "int64", # Argument count + "arg_avg ms": "int64", # Average milliseconds + "arg_name": "category", # Argument name (medium cardinality, use categorical) +} + + +@dataclass +class FileMetadata: + """ + Processed metadata with computed fields for compilation analysis. + + This extends the raw metadata with derived values like formatted timestamps + and converted time units for convenience. + + Attributes: + source_file: Main .cpp/.c file being compiled + time_granularity: Time unit used in trace (always "microseconds" for Clang) + beginning_of_time: Epoch timestamp in microseconds from JSON root + execute_compiler_ts: Timestamp of ExecuteCompiler event (microseconds) + execute_compiler_dur: Duration of ExecuteCompiler event (microseconds) + total_wall_time_us: Total compilation time in microseconds (same as execute_compiler_dur) + total_wall_time_s: Total compilation time in seconds (computed from microseconds) + wall_start_time: Wall clock start time in microseconds since epoch (computed) + wall_end_time: Wall clock end time in microseconds since epoch (computed) + wall_start_datetime: Human-readable start time string (formatted) + wall_end_datetime: Human-readable end time string (formatted) + """ + + source_file: Optional[str] = None + time_granularity: str = "microseconds" + beginning_of_time: Optional[int] = None + execute_compiler_ts: Optional[int] = None + execute_compiler_dur: Optional[int] = None + total_wall_time_us: Optional[int] = None + total_wall_time_s: Optional[float] = None + wall_start_time: Optional[int] = None + wall_end_time: Optional[int] = None + wall_start_datetime: Optional[str] = None + wall_end_datetime: Optional[str] = None + + def __repr__(self): + # auto-generate pretty lines + fields = "\n".join( + f" {name} = {value!r}" for name, value in self.__dict__.items() + ) + return f"{self.__class__.__name__}(\n{fields}\n)" + + +def parse_file(filepath: Union[str, Path]) -> pd.DataFrame: + """ + Parse a Clang -ftime-trace JSON file into a Pandas DataFrame. + + The -ftime-trace format is a JSON array of trace events. Each event contains + fields like name, phase (ph), timestamp (ts), duration (dur), process/thread IDs, + and optional arguments (args). + + The beginningOfTime value from the JSON structure is automatically extracted + and stored in df.attrs['beginningOfTime']. Use get_metadata(df) to get + processed metadata with event-derived fields and computed values. + + Args: + filepath: Path to the -ftime-trace JSON file + + Returns: + DataFrame with columns for each event field. Nested 'args' are flattened + with an 'arg_' prefix. The beginningOfTime value is stored in + df.attrs['beginningOfTime']. + + Raises: + FileNotFoundError: If the file doesn't exist + ValueError: If the JSON is invalid or empty + + Examples: + >>> df = parse_file('build/trace.json') + >>> df[['name', 'dur']].head() + >>> + >>> # Access processed metadata + >>> metadata = get_metadata(df) + >>> print(f"Compiled: {metadata.source_file}") + >>> print(f"Duration: {metadata.total_wall_time_s:.2f}s") + >>> + >>> # Access beginningOfTime directly if needed + >>> beginning = df.attrs.get('beginningOfTime') + >>> print(f"Beginning of time: {beginning}") + """ + filepath = Path(filepath) + + if not filepath.exists(): + raise FileNotFoundError(f"Trace file not found: {filepath}") + + # Read and parse JSON using orjson for speed + with open(filepath, "rb") as f: + data = orjson.loads(f.read()) + + if not data: + raise ValueError(f"Empty trace data in file: {filepath}") + + # Handle both formats: direct array or {"traceEvents": [...]} + if isinstance(data, dict): + if "traceEvents" in data: + events = data["traceEvents"] + else: + raise ValueError( + f"Expected 'traceEvents' key in JSON object, got keys: {list(data.keys())}" + ) + elif isinstance(data, list): + events = data + else: + raise ValueError(f"Expected JSON array or object, got {type(data).__name__}") + + # Convert to DataFrame + df = pd.DataFrame(events) + + if df.empty: + raise ValueError(f"No trace events found in file: {filepath}") + + # Flatten 'args' column if it exists + if "args" in df.columns: + df = _flatten_args(df) + + # Validate schema: check for missing columns + expected_columns = set(TRACE_EVENT_DTYPES.keys()) + actual_columns = set(df.columns) + + missing_columns = expected_columns - actual_columns + if missing_columns: + raise ValueError( + f"Missing expected columns in trace data: {sorted(missing_columns)}" + ) + + # Validate schema: check for unexpected columns + unexpected_columns = actual_columns - expected_columns + if unexpected_columns: + raise ValueError( + f"Unexpected columns found in trace data: {sorted(unexpected_columns)}" + ) + + # Apply optimized dtypes with strict type enforcement + for col, dtype in TRACE_EVENT_DTYPES.items(): + if dtype in ("int64", "int32"): + # Fill missing values with 0 for integer columns, then convert to specified int type + df[col] = df[col].fillna(0).astype(dtype) + elif dtype == "category": + # Convert to categorical for memory efficiency with repeated values + df[col] = df[col].astype("category") + elif dtype == "string": + # Convert to pandas string dtype for memory efficiency + df[col] = df[col].astype("string") + else: + raise ValueError(f"Unsupported dtype '{dtype}' for column '{col}'") + + # Extract and store beginningOfTime in DataFrame attributes + df.attrs["beginningOfTime"] = ( + data.get("beginningOfTime") if isinstance(data, dict) else None + ) + + return df + + +def _flatten_args(df: pd.DataFrame) -> pd.DataFrame: + """ + Flatten the 'args' column into separate columns with 'arg_' prefix. + + The 'args' field in trace events contains additional metadata as a dictionary. + This function extracts those key-value pairs into separate columns. + + Args: + df: DataFrame with an 'args' column containing dictionaries + + Returns: + DataFrame with flattened args columns and original 'args' column removed + """ + # Extract args into separate DataFrame + args_data = [] + for idx, row in df.iterrows(): + args = row.get("args", {}) + if isinstance(args, dict): + args_data.append(args) + else: + args_data.append({}) + + if args_data: + args_df = pd.DataFrame(args_data) + # Prefix all args columns with 'arg_' + args_df.columns = [f"arg_{col}" for col in args_df.columns] + + # Drop original args column and concatenate flattened args + df = df.drop(columns=["args"]) + df = pd.concat([df, args_df], axis=1) + + return df + + +def _normalize_source_path(file_path: str) -> str: + """ + Normalize a source file path to be relative to composable_kernel if present. + + If 'composable_kernel' appears in the path, returns the path starting from + 'composable_kernel/'. Otherwise, returns the original path unchanged. + + Args: + file_path: Full filesystem path to a source file + + Returns: + Normalized path starting from composable_kernel, or original path if + composable_kernel is not found + + Examples: + >>> _normalize_source_path('/home/user/composable_kernel/include/ck/tensor.hpp') + 'composable_kernel/include/ck/tensor.hpp' + >>> _normalize_source_path('/usr/include/vector') + '/usr/include/vector' + """ + path = Path(file_path) + parts = path.parts + + # Find the last occurrence of 'composable_kernel' in the path + for i in range(len(parts) - 1, -1, -1): + if parts[i] == "composable_kernel": + # Return path from composable_kernel onwards + return str(Path(*parts[i:])) + + # If composable_kernel not found, return original path + return file_path + + +def get_metadata(df: pd.DataFrame) -> FileMetadata: + """ + Extract and process compilation metadata from a DataFrame. + + This function processes events from the DataFrame to extract compilation + information, then computes derived fields like formatted timestamps and + converted time units. + + Args: + df: DataFrame returned by parse_file() with beginningOfTime in its .attrs + + Returns: + FileMetadata instance with both raw and computed fields: + - source_file: Main .cpp/.c file being compiled (from events) + - time_granularity: Time unit used in trace ("microseconds") + - beginning_of_time: Epoch timestamp in microseconds from JSON root + - execute_compiler_ts: Timestamp of ExecuteCompiler event (from events) + - execute_compiler_dur: Duration of ExecuteCompiler event (from events) + - total_wall_time_us: Total compilation time in microseconds + - total_wall_time_s: Total compilation time in seconds (computed) + - wall_start_time: Wall clock start time (computed) + - wall_end_time: Wall clock end time (computed) + - wall_start_datetime: Human-readable start time (formatted) + - wall_end_datetime: Human-readable end time (formatted) + + Examples: + >>> df = parse_file('trace.json') + >>> metadata = get_metadata(df) + >>> print(f"Compiled: {metadata.source_file}") + >>> print(f"Duration: {metadata.total_wall_time_s:.2f}s") + >>> print(f"Started: {metadata.wall_start_datetime}") + """ + # Extract beginningOfTime from DataFrame attributes + beginning_of_time = None + if hasattr(df, "attrs"): + beginning_of_time = df.attrs.get("beginningOfTime") + + # Initialize metadata with beginningOfTime from JSON structure + metadata = FileMetadata(beginning_of_time=beginning_of_time) + + # Process events to extract ExecuteCompiler timing information + if "name" in df.columns: + execute_compiler = df[df["name"] == "ExecuteCompiler"] + if not execute_compiler.empty: + # Get the first ExecuteCompiler event + event = execute_compiler.iloc[0] + if "ts" in event: + metadata.execute_compiler_ts = event["ts"] + if "dur" in event: + metadata.execute_compiler_dur = event["dur"] + + # Process events to find the main source file being compiled + if "name" in df.columns and "arg_detail" in df.columns: + # Look for ParseDeclarationOrFunctionDefinition events with .cpp or .c files + source_extensions = (".cpp", ".cc", ".cxx", ".c") + parse_events = df[df["name"] == "ParseDeclarationOrFunctionDefinition"] + + for _, event in parse_events.iterrows(): + detail = event.get("arg_detail", "") + if detail: + # Extract file path (may include line:column info) + file_path = str(detail).split(":")[0] + + # Check if it's a source file (not a header) + if any(file_path.endswith(ext) for ext in source_extensions): + metadata.source_file = _normalize_source_path(file_path) + break + + # Compute derived fields + if metadata.execute_compiler_dur is not None: + metadata.total_wall_time_us = metadata.execute_compiler_dur + metadata.total_wall_time_s = metadata.execute_compiler_dur / 1_000_000.0 + + # Calculate wall clock times if we have the necessary data + if ( + metadata.beginning_of_time is not None + and metadata.execute_compiler_ts is not None + and metadata.execute_compiler_dur is not None + ): + metadata.wall_start_time = ( + metadata.beginning_of_time + metadata.execute_compiler_ts + ) + metadata.wall_end_time = ( + metadata.wall_start_time + metadata.execute_compiler_dur + ) + + # Convert to human-readable datetime strings + try: + start_dt = datetime.fromtimestamp(metadata.wall_start_time / 1_000_000.0) + end_dt = datetime.fromtimestamp(metadata.wall_end_time / 1_000_000.0) + metadata.wall_start_datetime = start_dt.strftime("%Y-%m-%d %H:%M:%S.%f")[ + :-3 + ] + metadata.wall_end_datetime = end_dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + except (OSError, ValueError): + # Handle invalid timestamps gracefully + pass + + return metadata diff --git a/script/analyze_build/trace_analysis/phase_breakdown.py b/script/analyze_build/trace_analysis/phase_breakdown.py new file mode 100644 index 0000000000..773ba06622 --- /dev/null +++ b/script/analyze_build/trace_analysis/phase_breakdown.py @@ -0,0 +1,354 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Phase breakdown analysis for Clang -ftime-trace data. + +This module provides hierarchical breakdown of compilation phases using +the pre-aggregated "Total" events from Clang's -ftime-trace output. + +The data is returned as a PhaseBreakdown object with rich display and +analysis capabilities optimized for Jupyter notebooks. +""" + +import pandas as pd +from collections import namedtuple +from typing import Optional + + +# Lightweight namedtuple for iteration +Phase = namedtuple("Phase", ["name", "depth", "duration", "duration_ms", "percentage"]) + + +class PhaseBreakdown: + """ + Wrapper for compilation phase breakdown with notebook-friendly API. + + Provides hierarchical view of compilation phases from Clang -ftime-trace, + with rich display, filtering, and visualization capabilities. + + Examples: + >>> breakdown = get_phase_breakdown(df) + >>> + >>> # Display in Jupyter + >>> breakdown + >>> + >>> # Access specific phases + >>> breakdown['InstantiateFunction'] + >>> breakdown.frontend + >>> breakdown.backend + >>> + >>> # Get metrics + >>> print(f"Total: {breakdown.total_ms:.1f}ms") + >>> + >>> # Top N analysis + >>> breakdown.top(10) + >>> breakdown.frontend.top(5) + >>> + >>> # Visualization + >>> import plotly.express as px + >>> data = breakdown.to_plotly() + >>> fig = px.sunburst(**data) + >>> fig.show() + >>> + >>> # Iteration + >>> for phase in breakdown: + >>> print(f"{phase.name}: {phase.duration_ms:.1f}ms") + """ + + def __init__(self, df: pd.DataFrame): + """ + Initialize from phase breakdown DataFrame. + + Args: + df: DataFrame with columns name, parent, depth, duration + """ + if df.empty: + self._df = pd.DataFrame(columns=["name", "parent", "depth", "duration"]) + self._total_time = 0 + else: + self._df = df + self._total_time = self._get_total_time() + + def __repr__(self) -> str: + """Simple text representation for console.""" + if self._df.empty: + return "PhaseBreakdown(empty)" + n_phases = len(self._df) + return f"PhaseBreakdown({n_phases} phases, {self._total_time:.1f}ms total)" + + def _repr_html_(self) -> str: + """Rich HTML representation for Jupyter notebooks.""" + if self._df.empty: + return "
PhaseBreakdown(empty)
" + return self.to_dataframe()._repr_html_() + + @property + def df(self) -> pd.DataFrame: + """ + Access underlying DataFrame. + + Returns: + DataFrame with columns name, parent, depth, duration + """ + return self._df + + def to_dataframe(self, show_percentages: bool = True) -> pd.DataFrame: + """ + Format as DataFrame for display. + + Creates a nicely formatted DataFrame suitable for Jupyter notebook display. + + Args: + show_percentages: Include percentage of total time + + Returns: + DataFrame with formatted columns + """ + return self._format_dataframe(show_percentages) + + def to_plotly(self) -> dict: + """ + Convert to plotly hierarchical visualization format. + + Returns a dictionary with data_frame, values, and path that can be directly + used with plotly.express sunburst, treemap, or icicle charts. + + Returns: + Dictionary with keys: data_frame, values, path, branchvalues + + Example: + >>> data = breakdown.to_plotly() + >>> import plotly.express as px + >>> + >>> # Create sunburst chart + >>> fig = px.sunburst(**data) + >>> fig.show() + >>> + >>> # Create treemap chart + >>> fig = px.treemap(**data) + >>> fig.show() + >>> + >>> # Create icicle chart + >>> fig = px.icicle(**data) + >>> fig.show() + """ + return self._build_plotly_data() + + # Internal helper methods + + def _get_total_time(self) -> int: + """Get total time from root ExecuteCompiler event.""" + root = self._df[self._df["depth"] == 0] + if root.empty: + return 0 + return int(root.iloc[0]["duration"]) + + def _format_dataframe(self, show_percentages: bool) -> pd.DataFrame: + """Format phase breakdown as DataFrame.""" + if self._df.empty: + return pd.DataFrame() + + display_rows = [] + for _, row in self._df.iterrows(): + duration_ms = row["duration"] / 1000.0 + display_row = { + "Name": row["name"], + "Parent": row["parent"] if row["parent"] else "(root)", + "Depth": row["depth"], + "Duration (ms)": duration_ms, + } + if show_percentages and self._total_time > 0: + pct = row["duration"] / self._total_time * 100 + display_row["% of Total"] = pct + display_rows.append(display_row) + + display_df = pd.DataFrame(display_rows) + + if show_percentages: + display_df["% of Total"] = display_df["% of Total"].round(1) + + return display_df + + def _build_plotly_data(self) -> dict: + """Convert to plotly hierarchical visualization format.""" + return { + "data_frame": self._df, + "names": "name", + "parents": "parent", + "values": "duration", + "branchvalues": "total", + } + + +# Hierarchical phase specification +# There are over 100 totals in the JSON file, but a lot of them overlap. +# If the children total more than their parent, we will throw a ValueError. +# +# The hierarchy is specified as a nested dictionary where: +# - Keys are phase names (matching "Total " events in the trace) +# - Values are dictionaries of child phases (or empty dict {} for leaf nodes) +# - Empty string "" as a key means "calculate Other as residual" +# +# This structure supports arbitrary nesting depth. +PHASE_HIERARCHY = { + "ExecuteCompiler": { + "Frontend": { + "InstantiateFunction": {}, + }, + "Backend": { + "Optimizer": {}, + "CodeGenPasses": {}, + }, + } +} + + +def get_phase_breakdown(df: pd.DataFrame) -> PhaseBreakdown: + """ + Get hierarchical breakdown of compilation phases. + + Returns a PhaseBreakdown object with rich display and analysis methods, + using the pre-aggregated "Total" events from Clang's -ftime-trace output + for accurate statistics. + + All durations are in microseconds. + + The hierarchy is defined by the PHASE_HIERARCHY constant and supports + arbitrary nesting depth. The tree is traversed recursively to build + the phase breakdown. + + Args: + df: DataFrame from parse_file() + + Returns: + PhaseBreakdown object with rich display and analysis methods + + Raises: + ValueError: If required Total events are missing or if calculated + "Other" values are negative (indicating data inconsistency) + + Examples: + >>> df = parse_file('trace.json') + >>> breakdown = get_phase_breakdown(df) + >>> + >>> # Display in Jupyter (automatic) + >>> breakdown + >>> + >>> # Get total compilation time + >>> print(f"Total: {breakdown.total_ms:.1f}ms") + >>> + >>> # Access specific phases + >>> breakdown['InstantiateFunction'] + >>> breakdown.frontend + >>> breakdown.backend.top(5) + >>> + >>> # Visualize + >>> import plotly.express as px + >>> data = breakdown.to_plotly() + >>> fig = px.sunburst(**data) + >>> fig.show() + """ + if "name" not in df.columns or "dur" not in df.columns: + raise ValueError("DataFrame missing required 'name' or 'dur' columns") + + # Pre-filter to Total events for efficient lookup + total_events = df[df["name"].str.startswith("Total ", na=False)].copy() + total_events["phase"] = total_events["name"].str.removeprefix("Total ") + + def get_duration(phase_name: str) -> Optional[int]: + """Get duration in microseconds from a Total event.""" + matches = total_events[total_events["phase"] == phase_name] + if matches.empty: + return None + return int(matches.iloc[0]["dur"]) + + def process_node( + node_name: str, + parent_name: str, + depth: int, + children_spec: dict, + ) -> list[dict]: + """ + Recursively process a node and its children in the phase hierarchy. + + Args: + node_name: Name of the current phase node + parent_name: Name of the parent phase (empty string for root) + depth: Current depth in the tree (0 for root) + children_spec: Dictionary of child phases to process + + Returns: + List of row dictionaries for this node and all descendants + + Raises: + ValueError: If phase not found or children exceed parent duration + """ + # Get duration for this node + node_duration = get_duration(node_name) + if node_duration is None: + raise ValueError(f"No Total {node_name} event found in trace") + + # Add current node + rows = [ + { + "name": node_name, + "parent": parent_name, + "depth": depth, + "duration": node_duration, + } + ] + + if not children_spec: + return rows + + # Process all children recursively + children_total = 0 + for child_name, grandchildren_spec in children_spec.items(): + if child_name == "": + # Empty string means "Other" - skip for now, calculate as residual + continue + + # Recursively process this child and its descendants + child_rows = process_node( + child_name, node_name, depth + 1, grandchildren_spec + ) + rows.extend(child_rows) + + # Track total duration of direct children only (not grandchildren) + children_total += child_rows[0]["duration"] + + # Calculate and add "Other" if there's unaccounted time + other_duration = node_duration - children_total + if other_duration < 0: + raise ValueError( + f"{node_name} children total ({children_total}) " + f"exceeds parent total ({node_duration})" + ) + + if other_duration > 0: + rows.append( + { + "name": f"{node_name}_Other", + "parent": node_name, + "depth": depth + 1, + "duration": other_duration, + } + ) + + return rows + + # Start recursive traversal from root + root_name = "ExecuteCompiler" + if root_name not in PHASE_HIERARCHY: + raise ValueError(f"Root phase '{root_name}' not found in PHASE_HIERARCHY") + + all_rows = process_node( + root_name, + "", # Root has no parent + 0, # Root is at depth 0 + PHASE_HIERARCHY[root_name], + ) + + breakdown_df = pd.DataFrame(all_rows) + return PhaseBreakdown(breakdown_df) diff --git a/script/analyze_build/trace_analysis/template_analysis.py b/script/analyze_build/trace_analysis/template_analysis.py new file mode 100644 index 0000000000..ef483f6f53 --- /dev/null +++ b/script/analyze_build/trace_analysis/template_analysis.py @@ -0,0 +1,80 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Template instantiation analysis for Clang -ftime-trace data. + +This module provides specialized functions for analyzing C++ template +instantiation costs from Clang's -ftime-trace output. +""" + +import pandas as pd +from .template_parser import parse_template_detail + + +def get_template_instantiation_events(df: pd.DataFrame) -> pd.DataFrame: + """ + Filter to template instantiation events and parse arg_detail into structured columns. + + Returns events for: + - InstantiateFunction: Function template instantiations + - InstantiateClass: Class template instantiations + + The returned DataFrame includes parsed columns from arg_detail: + - namespace: Top-level namespace (e.g., 'std', 'ck') + - template_name: Template name without parameters + - full_qualified_name: Full namespace::template_name + - param_count: Number of template parameters + - is_ck_type: Boolean indicating if this is a CK library type + - is_nested: Boolean indicating if contains nested templates + + Args: + df: DataFrame from parse_file() + + Returns: + Filtered DataFrame containing template instantiation events with parsed columns + + Example: + >>> df = parse_file('trace.json') + >>> templates = get_template_instantiation_events(df) + >>> templates.sort_values('dur', ascending=False).head(10) + >>> # Filter to CK types only + >>> ck_templates = templates[templates['is_ck_type']] + >>> # Group by template name + >>> templates.groupby('template_name')['dur'].sum() + """ + # Filter to template instantiation events + filtered_df = ( + df[ + df["name"].isin( + [ + "InstantiateClass", + "InstantiateFunction", + ] + ) + ] + .drop( + columns=[ + "arg_avg ms", + "arg_count", + "arg_name", + "cat", + "id", + "ph", + "pid", + "tid", + ] + ) + .reset_index(drop=True) + ) + + # Parse arg_detail into structured columns + parsed_data = filtered_df["arg_detail"].apply(parse_template_detail) + + # Convert list of dicts to DataFrame and join with original + parsed_df = pd.DataFrame(parsed_data.tolist()) + + # Combine with original data + result_df = pd.concat([filtered_df, parsed_df], axis=1) + + return result_df diff --git a/script/analyze_build/trace_analysis/template_parser.py b/script/analyze_build/trace_analysis/template_parser.py new file mode 100644 index 0000000000..2551465bd4 --- /dev/null +++ b/script/analyze_build/trace_analysis/template_parser.py @@ -0,0 +1,301 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Template detail string parser for C++ template instantiations. + +This module provides functions to parse the arg_detail strings from +Clang's -ftime-trace output into structured components. +""" + +import re +from typing import Dict + + +def parse_template_detail(detail_str: str) -> Dict[str, any]: + """ + Parse a template detail string into structured components. + + Args: + detail_str: The arg_detail string from -ftime-trace + + Returns: + Dictionary with parsed fields: + - namespace: Top-level namespace (e.g., 'std', 'ck') + - template_name: Template name without parameters + - full_qualified_name: Full namespace::template_name + - param_count: Number of template parameters + - is_ck_type: Boolean indicating if this is a CK library type + - is_nested: Boolean indicating if contains nested templates + + Example: + >>> parse_template_detail('std::basic_string') + { + 'namespace': 'std', + 'template_name': 'basic_string', + 'full_qualified_name': 'std::basic_string', + 'param_count': 1, + 'is_ck_type': False, + 'is_nested': False + } + """ + # Handle empty or invalid strings + if not detail_str or not isinstance(detail_str, str): + return _empty_result() + + # Remove surrounding quotes if present + detail_str = detail_str.strip('"') + + # Extract components + namespace = extract_namespace(detail_str) + template_name = extract_template_name(detail_str) + full_qualified_name = extract_full_qualified_name(detail_str) + param_count = count_template_params(detail_str) + is_ck = is_ck_template(detail_str) + is_nested = is_nested_template(detail_str) + + return { + "namespace": namespace, + "template_name": template_name, + "full_qualified_name": full_qualified_name, + "param_count": param_count, + "is_ck_type": is_ck, + "is_nested": is_nested, + } + + +def extract_namespace(detail_str: str) -> str: + """ + Extract the top-level namespace from a template detail string. + + Args: + detail_str: The template detail string + + Returns: + The top-level namespace, or empty string if none found + + Example: + >>> extract_namespace('std::basic_string') + 'std' + >>> extract_namespace('ck::tensor_operation::device::DeviceConv2d<...>') + 'ck' + """ + if not detail_str: + return "" + + # Remove quotes + detail_str = detail_str.strip('"') + + # Find first :: separator + match = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)::", detail_str) + if match: + return match.group(1) + + # No namespace found - check if it's a simple type + match = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)", detail_str) + if match: + return match.group(1) + + return "" + + +def extract_template_name(detail_str: str) -> str: + """ + Extract the template name without namespace or parameters. + + Args: + detail_str: The template detail string + + Returns: + The template name without namespace or parameters + + Example: + >>> extract_template_name('std::basic_string') + 'basic_string' + >>> extract_template_name('ck::GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<...>') + 'GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3' + """ + if not detail_str: + return "" + + # Remove quotes + detail_str = detail_str.strip('"') + + # Find the last component before < or end of string + # This handles nested namespaces like ck::tensor_operation::device::DeviceConv2d + match = re.search(r"::([a-zA-Z_][a-zA-Z0-9_]*)\s*(?:<|$)", detail_str) + if match: + return match.group(1) + + # No :: found, try to get name before < + match = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)\s*(?:<|$)", detail_str) + if match: + return match.group(1) + + return "" + + +def extract_full_qualified_name(detail_str: str) -> str: + """ + Extract the full qualified name (namespace::...::template_name). + + Args: + detail_str: The template detail string + + Returns: + The full qualified name without template parameters + + Example: + >>> extract_full_qualified_name('std::basic_string') + 'std::basic_string' + >>> extract_full_qualified_name('ck::tensor_operation::device::DeviceConv2d<...>') + 'ck::tensor_operation::device::DeviceConv2d' + """ + if not detail_str: + return "" + + # Remove quotes + detail_str = detail_str.strip('"') + + # Match everything up to the first < or end of string + match = re.match(r"^([a-zA-Z_:][a-zA-Z0-9_:]*)\s*(?:<|$)", detail_str) + if match: + return match.group(1) + + return "" + + +def count_template_params(detail_str: str) -> int: + """ + Count the number of top-level template parameters. + + This counts commas at the top level of template brackets, + not commas inside nested templates. + + Args: + detail_str: The template detail string + + Returns: + Number of template parameters, or 0 if not a template + + Example: + >>> count_template_params('std::basic_string') + 1 + >>> count_template_params('std::tuple') + 3 + """ + if not detail_str or "<" not in detail_str: + return 0 + + # Remove quotes + detail_str = detail_str.strip('"') + + # Find the template parameter section + start = detail_str.find("<") + if start == -1: + return 0 + + # Track bracket depth to only count top-level commas + depth = 0 + param_count = 1 # Start with 1 (if there's a <, there's at least one param) + in_template = False + + for i in range(start, len(detail_str)): + char = detail_str[i] + + if char == "<": + depth += 1 + in_template = True + elif char == ">": + depth -= 1 + if depth == 0: + # We've closed the outermost template + break + elif char == "," and depth == 1: + # Top-level comma + param_count += 1 + + return param_count if in_template else 0 + + +def is_ck_template(detail_str: str) -> bool: + """ + Check if this is a CK library template. + + Args: + detail_str: The template detail string + + Returns: + True if this is a CK library type, False otherwise + + Example: + >>> is_ck_template('ck::tensor_operation::device::DeviceConv2d<...>') + True + >>> is_ck_template('std::basic_string') + False + """ + if not detail_str: + return False + + # Remove quotes + detail_str = detail_str.strip('"') + + # Check if it starts with ck:: or contains ::ck:: + return detail_str.startswith("ck::") or "::ck::" in detail_str + + +def is_nested_template(detail_str: str) -> bool: + """ + Check if this template contains nested template instantiations. + + Args: + detail_str: The template detail string + + Returns: + True if contains nested templates, False otherwise + + Example: + >>> is_nested_template('std::vector') + False + >>> is_nested_template('std::vector') + True + """ + if not detail_str or "<" not in detail_str: + return False + + # Remove quotes + detail_str = detail_str.strip('"') + + # Find the template parameter section + start = detail_str.find("<") + if start == -1: + return False + + # Look for nested < after the first one + depth = 0 + for i in range(start, len(detail_str)): + char = detail_str[i] + + if char == "<": + depth += 1 + if depth > 1: + # Found a nested template + return True + elif char == ">": + depth -= 1 + if depth == 0: + break + + return False + + +def _empty_result() -> Dict[str, any]: + """Return an empty result dictionary with default values.""" + return { + "namespace": "", + "template_name": "", + "full_qualified_name": "", + "param_count": 0, + "is_ck_type": False, + "is_nested": False, + } diff --git a/script/tools/ck-build b/script/tools/ck-build new file mode 100755 index 0000000000..2c0bb24eda --- /dev/null +++ b/script/tools/ck-build @@ -0,0 +1,143 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Build - Build Composable Kernel targets in Docker + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Build - Build Composable Kernel targets in Docker + +Usage: ck-build [options] [target...] + +Options: + -h, --help Show this help message + --name Specify container name + --reconfigure Reconfigure CMake before building + -j Parallel jobs (passed to ninja) + --clean Clean before building + +Arguments: + target Target(s) to build (default: all) + +Environment: + CK_CONTAINER_NAME - Override default container name + GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942) + +Examples: + ck-build # Build all targets + ck-build test_amdgcn_mma # Build specific target + ck-build test_amdgcn_mma test_gemm # Build multiple targets + ck-build --reconfigure # Reconfigure CMake and build all + ck-build --clean test_amdgcn_mma # Clean and build target + ck-build -j 8 test_amdgcn_mma # Build with 8 parallel jobs + +EOF +} + +# Parse arguments +targets=() +reconfigure=false +clean=false +parallel_jobs="" + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + --reconfigure) + reconfigure=true + shift + ;; + --clean) + clean=true + shift + ;; + -j) + parallel_jobs="-j $2" + shift 2 + ;; + *) + targets+=("$1") + shift + ;; + esac +done + +# Ensure container is running +if ! container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' not running. Starting..." + "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" + echo "" +fi + +# Configure CMake if needed or requested +if [ "$reconfigure" = true ] || ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then + echo "Detecting GPU target..." + GPU_TARGET_DETECTED=$(detect_gpu_target "${CONTAINER_NAME}") + + if [ "$reconfigure" = true ]; then + echo "Reconfiguring CMake from scratch for GPU target: ${GPU_TARGET_DETECTED}" + else + echo "Configuring build with CMake for GPU target: ${GPU_TARGET_DETECTED}" + fi + + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace || exit 1 + rm -rf /workspace/build + mkdir /workspace/build + cd /workspace/build || exit 1 + cmake .. -GNinja \ + -DGPU_TARGETS=${GPU_TARGET_DETECTED} \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ + -DBUILD_TESTING=ON 2>&1 | tail -30 + " + echo "" +fi + +# Clean if requested +if [ "$clean" = true ]; then + echo "Cleaning build directory..." + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja clean + " + echo "" +fi + +# Build targets +if [ ${#targets[@]} -eq 0 ]; then + echo "Building all configured targets..." + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja ${parallel_jobs} 2>&1 + " +else + echo "Building targets: ${targets[*]}" + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja ${parallel_jobs} ${targets[*]} 2>&1 + " +fi + +echo "" +echo "Build complete ✓" diff --git a/script/tools/ck-clean b/script/tools/ck-clean new file mode 100755 index 0000000000..4b422f81f4 --- /dev/null +++ b/script/tools/ck-clean @@ -0,0 +1,113 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Clean - Clean build artifacts in Docker container + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Clean - Clean build artifacts in Docker container + +Usage: ck-clean [options] + +Options: + -h, --help Show this help message + --name Specify container name + --all Remove entire build directory + -f, --force Force without confirmation + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-clean # Clean build artifacts (ninja clean) + ck-clean --all # Remove entire build directory + ck-clean --force --all # Remove build directory without confirmation + +EOF +} + +# Parse arguments +remove_all=false +force=false + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + --all) + remove_all=true + shift + ;; + -f|--force) + force=true + shift + ;; + *) + echo "Unknown option: $1" + show_help + exit 1 + ;; + esac +done + +# Check if container is running +if ! container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' not running" + echo "Start with: ck-start" + exit 1 +fi + +# Check if build directory exists +if ! docker exec "${CONTAINER_NAME}" test -d /workspace/build 2>/dev/null; then + echo "Build directory does not exist" + exit 0 +fi + +if [ "$remove_all" = true ]; then + # Remove entire build directory + if [ "$force" = false ]; then + read -p "Remove entire build directory? (y/N) " -n 1 -r + echo "" + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + echo "Cancelled" + exit 0 + fi + fi + + echo "Removing build directory..." + docker exec "${CONTAINER_NAME}" bash -c "rm -rf /workspace/build" + echo "Build directory removed ✓" +else + # Clean with ninja + if ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then + echo "Build not configured (build.ninja not found)" + echo "Use --all to remove build directory" + exit 1 + fi + + echo "Cleaning build artifacts..." + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja clean + " + echo "Build artifacts cleaned ✓" +fi diff --git a/script/tools/ck-exec b/script/tools/ck-exec new file mode 100755 index 0000000000..dfc7655774 --- /dev/null +++ b/script/tools/ck-exec @@ -0,0 +1,111 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Exec - Execute arbitrary commands in Docker container + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Exec - Execute arbitrary commands in Docker container + +Usage: ck-exec [options] [args...] + +Options: + -h, --help Show this help message + --name Specify container name + -w Working directory (default: /workspace) + -i, --interactive Interactive mode (allocate TTY) + +Arguments: + command Command to execute (required) + args Arguments to the command + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-exec rocm-smi # Run rocm-smi + ck-exec rocminfo # Run rocminfo + ck-exec ls -la build/bin # List build binaries + ck-exec -w /workspace/build ninja -t commands # Run ninja commands + ck-exec --interactive python3 # Interactive Python session + +Common Commands: + ck-exec rocm-smi # Check GPU status + ck-exec rocminfo \| grep gfx # Check GPU architecture + ck-exec hipcc --version # Check HIP compiler version + ck-exec cmake --version # Check CMake version + ck-exec ninja -C build -t targets # List all build targets + +EOF +} + +# Parse arguments +workdir="/workspace" +interactive=false +command_args=() + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + -w) + workdir="$2" + shift 2 + ;; + -i|--interactive) + interactive=true + shift + ;; + *) + command_args+=("$1") + shift + ;; + esac +done + +# Validate command +if [ ${#command_args[@]} -eq 0 ]; then + echo "Error: command required" + echo "" + show_help + exit 1 +fi + +# Ensure container is running +if ! container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' not running. Starting..." + "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" + echo "" +fi + +# Build command string +cmd_string="" +for arg in "${command_args[@]}"; do + cmd_string="${cmd_string} $(printf '%q' "$arg")" +done + +# Execute command +if [ "$interactive" = true ]; then + docker exec -it -w "${workdir}" "${CONTAINER_NAME}" bash -c "${cmd_string}" +else + docker exec -w "${workdir}" "${CONTAINER_NAME}" bash -c "${cmd_string}" +fi diff --git a/script/tools/ck-logs b/script/tools/ck-logs new file mode 100755 index 0000000000..cfad23b3b5 --- /dev/null +++ b/script/tools/ck-logs @@ -0,0 +1,134 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Logs - View container logs and build output + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Logs - View container logs and build output + +Usage: ck-logs [options] [container_name] + +Options: + -h, --help Show this help message + --name Specify container name + -f, --follow Follow log output + -n, --tail Show last N lines (default: 100) + --cmake Show CMake configuration log + --build Show last build log + +Arguments: + container_name Optional container name (default: ck__) + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-logs # Show last 100 lines of container logs + ck-logs -f # Follow container logs + ck-logs -n 500 # Show last 500 lines + ck-logs --cmake # Show CMake configuration + ck-logs --build # Show build log + +EOF +} + +# Parse arguments +follow=false +tail_lines=100 +show_cmake=false +show_build=false + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + -f|--follow) + follow=true + shift + ;; + -n|--tail) + tail_lines="$2" + shift 2 + ;; + --cmake) + show_cmake=true + shift + ;; + --build) + show_build=true + shift + ;; + *) + CONTAINER_NAME="$1" + shift + ;; + esac +done + +# Check if container exists +if ! container_exists "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' does not exist" + exit 1 +fi + +# Show CMake log +if [ "$show_cmake" = true ]; then + echo "CMake Configuration Log:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + + if docker exec "${CONTAINER_NAME}" test -f /workspace/build/CMakeCache.txt 2>/dev/null; then + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build + echo 'GPU_TARGETS:' \$(grep 'GPU_TARGETS:' CMakeCache.txt | cut -d'=' -f2) + echo 'CMAKE_BUILD_TYPE:' \$(grep 'CMAKE_BUILD_TYPE:' CMakeCache.txt | cut -d'=' -f2) + echo 'CMAKE_CXX_COMPILER:' \$(grep 'CMAKE_CXX_COMPILER:' CMakeCache.txt | cut -d'=' -f2) + echo 'BUILD_TESTING:' \$(grep 'BUILD_TESTING:' CMakeCache.txt | cut -d'=' -f2) + " + else + echo "CMake not configured yet" + fi + exit 0 +fi + +# Show build log (last build output) +if [ "$show_build" = true ]; then + echo "Last Build Log:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + + if docker exec "${CONTAINER_NAME}" test -f /workspace/build/.ninja_log 2>/dev/null; then + docker exec "${CONTAINER_NAME}" bash -c "tail -50 /workspace/build/.ninja_log" + else + echo "No build log found" + fi + exit 0 +fi + +# Show container logs +echo "Container Logs (${CONTAINER_NAME}):" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + +if [ "$follow" = true ]; then + docker logs -f "${CONTAINER_NAME}" +else + docker logs --tail "${tail_lines}" "${CONTAINER_NAME}" +fi diff --git a/script/tools/ck-shell b/script/tools/ck-shell new file mode 100755 index 0000000000..785c9f4d68 --- /dev/null +++ b/script/tools/ck-shell @@ -0,0 +1,84 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Shell - Open interactive shell in Docker container + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Shell - Open interactive shell in Docker container + +Usage: ck-shell [options] [container_name] + +Options: + -h, --help Show this help message + --name Specify container name + -c Execute command instead of interactive shell + +Arguments: + container_name Optional container name (default: ck__) + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-shell # Open interactive shell + ck-shell my_container # Open shell in specific container + ck-shell -c "rocm-smi" # Execute single command + ck-shell -c "cd build && ls bin" # Execute command in build directory + +EOF +} + +# Parse arguments +command="" + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + -c) + command="$2" + shift 2 + ;; + *) + CONTAINER_NAME="$1" + shift + ;; + esac +done + +# Ensure container is running +if ! container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' not running. Starting..." + "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" + echo "" +fi + +# Execute command or open shell +if [ -n "$command" ]; then + echo "Executing: ${command}" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + docker exec "${CONTAINER_NAME}" bash -c "${command}" +else + echo "Opening shell in '${CONTAINER_NAME}' (type 'exit' to leave)..." + docker exec -it "${CONTAINER_NAME}" bash +fi diff --git a/script/tools/ck-start b/script/tools/ck-start new file mode 100755 index 0000000000..f15477492a --- /dev/null +++ b/script/tools/ck-start @@ -0,0 +1,103 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Start - Start Docker container for Composable Kernel testing + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Start - Start Docker container for Composable Kernel testing + +Usage: ck-start [options] [container_name] + +Options: + -h, --help Show this help message + --image Specify Docker image (overrides CK_DOCKER_IMAGE) + +Arguments: + container_name Optional container name (default: ck__) + +Environment: + CK_CONTAINER_NAME - Override default container name + CK_DOCKER_IMAGE - Override Docker image (default: rocm/composable_kernel:ck_ub24.04_rocm7.0.1) + +Examples: + ck-start # Start container with default name + ck-start my_ck_container # Start container with custom name + ck-start --image rocm/composable_kernel:latest + +EOF +} + +# Parse arguments +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --image) + export CK_DOCKER_IMAGE="$2" + shift 2 + ;; + *) + CONTAINER_NAME="$1" + shift + ;; + esac +done + +# Get Docker image +DOCKER_IMAGE=$(get_docker_image) + +# Check if container exists and is running +if container_exists "${CONTAINER_NAME}"; then + if container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' is already running" + docker exec "${CONTAINER_NAME}" bash -c "echo 'Working directory:' && pwd" + exit 0 + else + echo "Starting existing container '${CONTAINER_NAME}'..." + docker start "${CONTAINER_NAME}" + echo "Container started" + docker exec "${CONTAINER_NAME}" bash -c "echo 'Working directory:' && pwd" + exit 0 + fi +fi + +# Create new container +echo "Creating new Docker container '${CONTAINER_NAME}'..." +echo "Docker image: ${DOCKER_IMAGE}" +echo "Project root: ${PROJECT_ROOT}" +echo "" + +docker run -d \ + --name "${CONTAINER_NAME}" \ + --device=/dev/kfd --device=/dev/dri \ + --security-opt seccomp=unconfined \ + --group-add video \ + -v "${PROJECT_ROOT}":/workspace \ + -w /workspace \ + "${DOCKER_IMAGE}" \ + tail -f /dev/null + +echo "" +echo "Container '${CONTAINER_NAME}' started successfully" +docker exec "${CONTAINER_NAME}" bash -c "echo 'Working directory:' && pwd" + +# Show GPU info +echo "" +echo "GPU Information:" +docker exec "${CONTAINER_NAME}" bash -c "rocm-smi --showproductname 2>/dev/null | head -5 || echo 'No GPU detected'" diff --git a/script/tools/ck-status b/script/tools/ck-status new file mode 100755 index 0000000000..fea9de8c36 --- /dev/null +++ b/script/tools/ck-status @@ -0,0 +1,153 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Status - Check container status and information + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Status - Check container status and information + +Usage: ck-status [options] [container_name] + +Options: + -h, --help Show this help message + --name Specify container name + --all Show all CK containers + -v, --verbose Show detailed information + +Arguments: + container_name Optional container name (default: ck__) + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-status # Check default container status + ck-status my_container # Check specific container + ck-status --all # Show all CK containers + ck-status -v # Show detailed information + +EOF +} + +# Parse arguments +show_all=false +verbose=false + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + --all) + show_all=true + shift + ;; + -v|--verbose) + verbose=true + shift + ;; + *) + CONTAINER_NAME="$1" + shift + ;; + esac +done + +DOCKER_IMAGE=$(get_docker_image) + +# Show all containers +if [ "$show_all" = true ]; then + echo "Composable Kernel Docker Containers:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + + username=$(get_username) + containers=$(docker ps -a --filter "name=ck_${username}_" --format "table {{.Names}}\t{{.Status}}\t{{.CreatedAt}}" 2>/dev/null || echo "") + + if [ -z "$containers" ] || [ "$containers" = "NAMES STATUS CREATED AT" ]; then + echo "No CK containers found for user '${username}'" + else + echo "$containers" + fi + exit 0 +fi + +# Check specific container status +echo "Container: ${CONTAINER_NAME}" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + +if container_is_running "${CONTAINER_NAME}"; then + echo "Status: RUNNING ✓" + echo "" + docker ps --filter "name=^${CONTAINER_NAME}$" --format "table {{.Names}}\t{{.Status}}\t{{.Image}}" + + if [ "$verbose" = true ]; then + echo "" + echo "Container Details:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + docker inspect "${CONTAINER_NAME}" --format ' +Image: {{.Config.Image}} +Created: {{.Created}} +Platform: {{.Platform}} +Mounts: {{range .Mounts}} + - {{.Source}} -> {{.Destination}}{{end}} +' + fi + + echo "" + echo "GPU Information:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + docker exec "${CONTAINER_NAME}" bash -c "rocm-smi --showproductname 2>/dev/null | head -10 || echo 'No GPU detected'" + + if [ "$verbose" = true ]; then + echo "" + echo "GPU Target:" + gpu_target=$(detect_gpu_target "${CONTAINER_NAME}") + echo " ${gpu_target}" + + echo "" + echo "Build Status:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + if docker exec "${CONTAINER_NAME}" test -d /workspace/build 2>/dev/null; then + if docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then + echo " CMake configured ✓" + echo " Build directory: /workspace/build" + + # Count built test binaries + bin_count=$(docker exec "${CONTAINER_NAME}" bash -c "ls -1 /workspace/build/bin 2>/dev/null | wc -l" || echo "0") + echo " Test binaries: ${bin_count}" + else + echo " CMake not configured" + fi + else + echo " Build directory not found" + fi + fi + +elif container_exists "${CONTAINER_NAME}"; then + echo "Status: STOPPED" + echo "" + echo "Start with: ck-start" +else + echo "Status: DOES NOT EXIST" + echo "" + echo "Create with: ck-start" +fi diff --git a/script/tools/ck-stop b/script/tools/ck-stop new file mode 100755 index 0000000000..b793f47408 --- /dev/null +++ b/script/tools/ck-stop @@ -0,0 +1,141 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Stop - Stop and remove Docker container + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Stop - Stop and remove Docker container + +Usage: ck-stop [options] [container_name] + +Options: + -h, --help Show this help message + -f, --force Force stop without confirmation + --all Stop all CK containers for this user + +Arguments: + container_name Optional container name (default: ck__) + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-stop # Stop default container + ck-stop my_ck_container # Stop specific container + ck-stop --all # Stop all user's CK containers + ck-stop --force # Stop without confirmation + +EOF +} + +# Parse arguments +force=false +stop_all=false + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + -f|--force) + force=true + shift + ;; + --all) + stop_all=true + shift + ;; + *) + CONTAINER_NAME="$1" + shift + ;; + esac +done + +# Function to stop a single container +stop_container() { + local name="$1" + + if ! container_exists "${name}"; then + echo "Container '${name}' does not exist" + return 1 + fi + + echo "Stopping and removing container '${name}'..." + docker stop "${name}" 2>/dev/null || true + docker rm "${name}" 2>/dev/null || true + echo "Container '${name}' stopped and removed" +} + +# Stop all user containers +if [ "$stop_all" = true ]; then + username=$(get_username) + containers=$(docker ps -a --filter "name=ck_${username}_" --format '{{.Names}}') + + if [ -z "$containers" ]; then + echo "No CK containers found for user '${username}'" + exit 0 + fi + + echo "Found CK containers for user '${username}':" + echo "$containers" + echo "" + + if [ "$force" = false ]; then + read -p "Stop and remove all these containers? (y/N) " -n 1 -r + echo "" + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + echo "Cancelled" + exit 0 + fi + fi + + echo "" + while IFS= read -r container; do + stop_container "$container" + done <<< "$containers" + + echo "" + echo "All containers stopped and removed" + exit 0 +fi + +# Stop single container +if ! container_exists "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' does not exist" + exit 0 +fi + +# Show container info +if container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' is currently running" +else + echo "Container '${CONTAINER_NAME}' exists but is stopped" +fi + +# Confirm if not forced +if [ "$force" = false ]; then + read -p "Stop and remove container '${CONTAINER_NAME}'? (y/N) " -n 1 -r + echo "" + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + echo "Cancelled" + exit 0 + fi +fi + +stop_container "${CONTAINER_NAME}" diff --git a/script/tools/ck-test b/script/tools/ck-test new file mode 100755 index 0000000000..712f904596 --- /dev/null +++ b/script/tools/ck-test @@ -0,0 +1,166 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Test - Build and test Composable Kernel in Docker + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Test - Build and test Composable Kernel in Docker + +Usage: ck-test [options] [test_options] + +Options: + -h, --help Show this help message + --name Specify container name + --reconfigure Reconfigure CMake before building + --no-build Skip building, run test directly + +Arguments: + test_name Name of test executable (required) + test_options Additional options passed to test (e.g., --gtest_filter=*) + +Environment: + CK_CONTAINER_NAME - Override default container name + GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942) + +Examples: + ck-test test_amdgcn_mma + ck-test test_amdgcn_mma --gtest_filter=*Fp16* + ck-test --name my_container test_amdgcn_mma + ck-test --reconfigure test_amdgcn_mma + +EOF +} + +# Parse arguments +test_name="" +reconfigure=false +no_build=false +test_options=() + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + --reconfigure) + reconfigure=true + shift + ;; + --no-build) + no_build=true + shift + ;; + --gtest_*|--help) + test_options+=("$1") + shift + ;; + *) + if [ -z "$test_name" ]; then + test_name="$1" + else + test_options+=("$1") + fi + shift + ;; + esac +done + +# Validate test name +if [ -z "$test_name" ]; then + echo "Error: test_name required" + echo "" + show_help + exit 1 +fi + +# Ensure container is running +if ! container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' not running. Starting..." + "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" + echo "" +fi + +# Configure CMake if needed or requested +if [ "$reconfigure" = true ] || ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then + echo "Detecting GPU target..." + GPU_TARGET_DETECTED=$(detect_gpu_target "${CONTAINER_NAME}") + + if [ "$reconfigure" = true ]; then + echo "Reconfiguring CMake from scratch for GPU target: ${GPU_TARGET_DETECTED}" + else + echo "Configuring build with CMake for GPU target: ${GPU_TARGET_DETECTED}" + fi + + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace || exit 1 + rm -rf /workspace/build + mkdir /workspace/build + cd /workspace/build || exit 1 + cmake .. -GNinja \ + -DGPU_TARGETS=${GPU_TARGET_DETECTED} \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ + -DBUILD_TESTING=ON 2>&1 | tail -30 + " + echo "" +fi + +# Build test if needed (unless --no-build is specified) +if [ "$no_build" = false ]; then + if ! docker exec "${CONTAINER_NAME}" test -f "/workspace/build/bin/${test_name}" 2>/dev/null; then + echo "Building ${test_name}..." + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja ${test_name} 2>&1 + " + echo "" + else + echo "Test executable found, rebuilding to ensure latest version..." + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja ${test_name} 2>&1 + " + echo "" + fi +fi + +# Run test +echo "Running: ${test_name} ${test_options[*]}" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + +# Build the command with proper quoting +cmd="cd /workspace/build && ./bin/${test_name}" +for opt in "${test_options[@]}"; do + cmd="${cmd} $(printf '%q' "$opt")" +done + +docker exec "${CONTAINER_NAME}" bash -c "${cmd}" +exit_code=$? + +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" +if [ $exit_code -eq 0 ]; then + echo "Test completed successfully" +else + echo "Test failed with exit code: ${exit_code}" +fi + +exit $exit_code diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp index ebe17aadd6..016f7be60d 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp @@ -13,13 +13,8 @@ class TestCkTileGemmPipelineCompV3 static constexpr bool check_data_type() { using Base = TestCkTileGemmPipeline>; - if constexpr(std::is_same_v && - std::is_same_v) - { - return false; - } - else if constexpr(std::is_same_v && - std::is_same_v) + if constexpr(std::is_same_v && + std::is_same_v) { return false; } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index 334e360eb5..4bef581254 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -170,7 +170,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, @@ -180,7 +180,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Row, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, @@ -190,7 +190,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, @@ -200,7 +200,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4> diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 5749a8d3b2..30c4eb11f9 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -11,7 +11,24 @@ list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") # Typed Test Suite for GEMM Quantization - split into multiple files to reduce compile time - # AQuant tests - split into 6 files + # AQuant tests - split into 10 files + + # AQuant Memory Pipeline tests + add_gtest_executable(test_tile_gemm_quant_aquant_mem_prefill_interwave + test_gemm_quant_aquant_mem_prefill_interwave.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_mem_prefill_interwave PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_aquant_mem_decode_intrawave + test_gemm_quant_aquant_mem_decode_intrawave.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_mem_decode_intrawave PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_aquant_mem_decode_interwave + test_gemm_quant_aquant_mem_decode_interwave.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_mem_decode_interwave PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_aquant_base_rcr test_gemm_quant_aquant_base_rcr.cpp ) @@ -150,10 +167,21 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_tensor PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # Target to build only AQuant memory pipeline tests + add_custom_target(test_tile_gemm_aquant_mem_all) + add_dependencies(test_tile_gemm_aquant_mem_all + test_tile_gemm_quant_aquant_mem_prefill_interwave + test_tile_gemm_quant_aquant_mem_decode_intrawave + test_tile_gemm_quant_aquant_mem_decode_interwave + ) + # Umbrella target to build all gemm quant tests add_custom_target(test_tile_gemm_quant_all) add_dependencies(test_tile_gemm_quant_all # AQuant tests + test_tile_gemm_quant_aquant_mem_prefill_interwave + test_tile_gemm_quant_aquant_mem_decode_intrawave + test_tile_gemm_quant_aquant_mem_decode_interwave test_tile_gemm_quant_aquant_base_rcr test_tile_gemm_quant_aquant_base_rrr_crr test_tile_gemm_quant_aquant_base_ccr diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_interwave.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_interwave.cpp new file mode 100644 index 0000000000..a7ab4120a1 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_interwave.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - Mem Decode Interwave Configuration +// Tuple format: +// clang-format off +using AQuantMemDecodeInterwaveTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Mem Decode Interwave +TYPED_TEST_SUITE(TestCkTileGemmAQuantMem, AQuantMemDecodeInterwaveTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuantMem, AQuantMemDecodeInterwaveTest) +{ + this->run_test_with_validation(16, 64, 512); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_intrawave.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_intrawave.cpp new file mode 100644 index 0000000000..483138d711 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_intrawave.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - Mem Decode Intrawave Configuration +// Tuple format: +// clang-format off +using AQuantMemDecodeIntrawaveTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Mem Decode Intrawave +TYPED_TEST_SUITE(TestCkTileGemmAQuantMem, AQuantMemDecodeIntrawaveTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuantMem, AQuantMemDecodeIntrawaveTest) +{ + this->run_test_with_validation(16, 64, 512); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_prefill_interwave.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_prefill_interwave.cpp new file mode 100644 index 0000000000..7e851d9bd3 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_prefill_interwave.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - Mem Prefill Interwave Configuration +// Tuple format: +// clang-format off +using AQuantMemPrefillInterwaveTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Mem Prefill Interwave +TYPED_TEST_SUITE(TestCkTileGemmAQuantMem, AQuantMemPrefillInterwaveTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuantMem, AQuantMemPrefillInterwaveTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp index 133c11860a..911af678df 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp @@ -25,9 +25,9 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using AQuantPrefillTypes = ::testing::Types< // RCR layout - with the Prefill BlockTile Config. - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 79c86935ef..9652dd449d 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -69,6 +69,38 @@ struct GemmConfigPrefill : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; +struct GemmConfigPrefillIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +struct GemmConfigPrefillInterwave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +struct GemmConfigDecodeIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +struct GemmConfigDecodeInterwave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + struct GemmConfigMxFp4 : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; @@ -374,6 +406,223 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase +class TestCkTileGemmAQuantMem + : public TestCkTileGemmQuantBase> +{ + using Base = TestCkTileGemmQuantBase>; + friend Base; + + public: + using typename Base::AccDataType; + using typename Base::ADataType; + using typename Base::ALayout; + using typename Base::AQLayout; + using typename Base::BDataType; + using typename Base::BLayout; + using typename Base::CDataType; + using typename Base::CLayout; + using typename Base::ComputeDataType; + using typename Base::QDataType; + using typename Base::QuantGroupSize; + static constexpr auto QuantType = Base::QuantType; + + protected: + void SetUpQuantTypeSpecific() {} + void TearDownQuantTypeSpecific() {} + // AQuant-specific data generation + void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K) + { + const ck_tile::index_t stride_A = + ck_tile::get_default_stride(M, K, 0, this->is_row_major(ALayout{})); + const ck_tile::index_t stride_B = + ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{})); + const ck_tile::index_t stride_C = + ck_tile::get_default_stride(M, N, 0, this->is_row_major(CLayout{})); + // AQuant uses grouped quantization for A matrix + const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize::kK); + // AQLayout is parameterized in the test tuple (can be RowMajor or ColumnMajor for AQuant) + const ck_tile::index_t stride_AQ = + ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(AQLayout{})); + // Generate test data + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); + // AQLayout is independently specified for each test case + ck_tile::HostTensor aq_m_aqk( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(AQLayout{}))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); + // Initialize data with random values + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f}(a_m_k); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f}(a_m_k); + } + ck_tile::FillUniformDistribution{-5.0f, 5.0f}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f}(aq_m_aqk); + // Allocate device memory + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType)); + ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size() * sizeof(QDataType)); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType)); + ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType)); + // Copy to device + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor temp = a_m_k; + ck_tile::permute_vectors_i4x4_b(temp); + a_m_k_dev_buf.ToDevice(temp.data()); + } + else + { + a_m_k_dev_buf.ToDevice(a_m_k.data()); + } + // aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); + if constexpr(Base::GemmConfig::PreshuffleQuant) + { + ck_tile::HostTensor aq_shuffle_host = + ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize::kK); + aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data()); + } + else + { + aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); + } + b_k_n_dev_buf.ToDevice(b_k_n.data()); + // Create args for kernel execution + ck_tile::QuantGemmHostArgs args{ + a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr + b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr + c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr + aq_m_aqk_dev_buf.GetDeviceBuffer(), // aq_ptr (scales) + nullptr, // bq_ptr (not used for AQuant) + 1, // k_batch + M, + N, + K, // M, N, K + AQK, // QK_A + 0, // QK_B (not used for AQuant) + stride_A, + stride_B, + stride_C, + stride_AQ, + 0 // strides + }; + // Run the kernel + ck_tile::stream_config stream_config{}; + this->invoke_quant_gemm(args, stream_config); + // Validation using reference implementation + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + // Run reference AQuant implementation + ck_tile::reference_gemm_quant(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref); + // Get device result + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{}))); + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data()); + // Calculate error tolerances + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = + this->template calculate_rtol_atol( + K, 1, max_accumulated_value); + // Validate results + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + EXPECT_TRUE(pass) << "AQuantGrouped validation failed with M=" << M << ", N=" << N + << ", K=" << K; + if(!pass) + { + std::cout << "AQuantGrouped - Relative error threshold: " + << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + } + + private: + // AQuant-specific pipeline implementation + template + void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args, + const ck_tile::stream_config& s) + { + using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; + const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr bool transpose_c = CodegenGemmTraits::TransposeC; + using PipelineProblem = ck_tile::GemmAQuantPipelineProblem; + using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrMem; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Base::M_Warp, + Base::N_Warp, + Base::M_Warp_Tile, + Base::N_Warp_Tile, + Base::K_Warp_Tile, + transpose_c>>; + using Kernel = ck_tile::QuantGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Arguments not supported for AQuant kernel"); + } + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + }; + return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } +}; + // BQuant-specific test fixture template class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase> diff --git a/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp b/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp index 98f466a2b3..3e4eb07a64 100644 --- a/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp +++ b/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp @@ -46,7 +46,7 @@ class TestConvndBwdData : public ::testing::Test ck::tensor_layout::convolution::NDHWK>>, DataType, DataType, - DataType>(true, // do_verification + DataType>(2, // do_verification: 2 = GPU reference 1, // init_method integer value false, // do_log false, // time_kernel diff --git a/test/convnd_fwd/convnd_fwd_xdl.cpp b/test/convnd_fwd/convnd_fwd_xdl.cpp index a2fdcaf870..0377b01bb2 100644 --- a/test/convnd_fwd/convnd_fwd_xdl.cpp +++ b/test/convnd_fwd/convnd_fwd_xdl.cpp @@ -47,7 +47,7 @@ class TestConvndFwd : public ::testing::Test ck::tensor_layout::convolution::NDHWK>>, DataType, DataType, - DataType>(true, // do_verification + DataType>(2, // do_verification: 2 = GPU reference 1, // init_method integer value false, // do_log false, // time_kernel diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc index 25d95cda3d..01d7d5a5fd 100644 --- a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc +++ b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc @@ -125,7 +125,7 @@ TYPED_TEST(TestGemmUniversal_FP16_KM_NK, MidLargeM) TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; @@ -139,7 +139,7 @@ TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK) TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; @@ -153,7 +153,7 @@ TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK) TYPED_TEST(TestGemmUniversal_FP16_KM_KN, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; @@ -169,7 +169,7 @@ TYPED_TEST(TestGemmUniversal_FP16_KM_KN, PaddK) TYPED_TEST(TestGemmUniversal_FP16_KM_NK, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; diff --git a/test/gpu_reference/CMakeLists.txt b/test/gpu_reference/CMakeLists.txt index 443818feb3..d1c3908849 100644 --- a/test/gpu_reference/CMakeLists.txt +++ b/test/gpu_reference/CMakeLists.txt @@ -4,6 +4,9 @@ add_gtest_executable(test_gpu_reference_conv_fwd test_gpu_reference_conv_fwd.cpp) target_link_libraries(test_gpu_reference_conv_fwd PRIVATE utility) +add_gtest_executable(test_gpu_reference_conv_fwd_multi_abd test_gpu_reference_conv_fwd_multi_abd.cpp) +target_link_libraries(test_gpu_reference_conv_fwd_multi_abd PRIVATE utility) + add_gtest_executable(test_gpu_reference_conv_bwd_data test_gpu_reference_conv_bwd_data.cpp) target_link_libraries(test_gpu_reference_conv_bwd_data PRIVATE utility) diff --git a/test/gpu_reference/gpu_reference_utils.hpp b/test/gpu_reference/gpu_reference_utils.hpp index fc017c8734..88306d51a4 100644 --- a/test/gpu_reference/gpu_reference_utils.hpp +++ b/test/gpu_reference/gpu_reference_utils.hpp @@ -381,5 +381,230 @@ bool test_conv_gpu_ref(const ck::utils::conv::ConvParam& params, ConvKernelType } } +// Forward convolution with D tensor support +template +bool test_conv_fwd_with_d_tensor_impl(const ck::utils::conv::ConvParam& params, + const Tensor& input_cpu, + const Tensor& weight_cpu, + const Tensor& d_cpu, + DeviceMem& input_dev, + DeviceMem& weight_dev, + DeviceMem& d_dev, + DeviceMem& output_dev, + OutElementOp out_element_op) +{ + using InElementOp = tensor_operation::element_wise::PassThrough; + using WeiElementOp = tensor_operation::element_wise::PassThrough; + + // Create D tensor lengths and strides for GPU reference + std::vector d_lengths_vec(NDimSpatial + 3); + d_lengths_vec[0] = params.G_; + d_lengths_vec[1] = params.N_; + d_lengths_vec[2] = params.K_; + for(index_t i = 0; i < NDimSpatial; ++i) + { + d_lengths_vec[3 + i] = static_cast(params.output_spatial_lengths_[i]); + } + + std::vector d_strides_vec = + ref::compute_conv_tensor_strides(d_lengths_vec, params.num_dim_spatial_); + + std::array d_ptrs = { + reinterpret_cast(d_dev.GetDeviceBuffer())}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; + + // Call GPU reference with D tensor + std::array in_ptrs = { + reinterpret_cast(input_dev.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(weight_dev.GetDeviceBuffer())}; + + ref::naive_conv_fwd_multi_abd<0, + 0, + 1, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + OutDataType>( // Explicitly specify TD = OutDataType + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(output_dev.GetDeviceBuffer()), + params, + d_lengths, + d_strides, + InElementOp{}, + WeiElementOp{}, + out_element_op); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Run CPU reference + std::vector strides_long(params.conv_filter_strides_.begin(), + params.conv_filter_strides_.end()); + std::vector dilations_long(params.conv_filter_dilations_.begin(), + params.conv_filter_dilations_.end()); + std::vector pads_long(params.input_left_pads_.begin(), + params.input_left_pads_.end()); + + Tensor input_ref = input_cpu; + Tensor weight_ref = weight_cpu; + Tensor output_ref( + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params)); + + std::array, 1> d_tensors_ref = {d_cpu}; + + auto ref_conv = tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_arg = ref_conv.MakeArgument(input_ref, + weight_ref, + output_ref, + strides_long, + dilations_long, + pads_long, + pads_long, + InElementOp{}, + WeiElementOp{}, + out_element_op, + {}, // A tensors + {}, // B tensors + d_tensors_ref); + ref_invoker.Run(ref_arg); + + // Copy result from device and compare + Tensor output_gpu(output_ref.mDesc); + output_dev.FromDevice(output_gpu.mData.data()); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Compare results + return ck::utils::check_err(output_gpu, output_ref); +} + +// Forward convolution with multiple A/B tensor support +template +bool test_conv_fwd_with_multi_ab_impl(const ck::utils::conv::ConvParam& params, + const Tensor& input_cpu, + const Tensor& weight_cpu, + const Tensor& a_extra_cpu, + const Tensor& b_extra_cpu, + DeviceMem& input_dev, + DeviceMem& weight_dev, + DeviceMem& a_extra_dev, + DeviceMem& b_extra_dev, + DeviceMem& output_dev, + InElementOp in_element_op, + WeiElementOp wei_element_op) +{ + using OutElementOp = tensor_operation::element_wise::PassThrough; + + // Call GPU reference with extra A and B tensors + std::array in_ptrs = { + reinterpret_cast(input_dev.GetDeviceBuffer()), + reinterpret_cast(a_extra_dev.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(weight_dev.GetDeviceBuffer()), + reinterpret_cast(b_extra_dev.GetDeviceBuffer())}; + std::array d_ptrs = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + ref::naive_conv_fwd_multi_abd<1, 1, 0, InLayout, WeiLayout, OutLayout>( + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(output_dev.GetDeviceBuffer()), + params, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + OutElementOp{}); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Run CPU reference + std::vector strides_long(params.conv_filter_strides_.begin(), + params.conv_filter_strides_.end()); + std::vector dilations_long(params.conv_filter_dilations_.begin(), + params.conv_filter_dilations_.end()); + std::vector pads_long(params.input_left_pads_.begin(), + params.input_left_pads_.end()); + + Tensor input_ref = input_cpu; + Tensor weight_ref = weight_cpu; + Tensor output_ref( + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params)); + + std::array, 1> a_tensors_ref = {a_extra_cpu}; + std::array, 1> b_tensors_ref = {b_extra_cpu}; + + auto ref_conv = tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_arg = ref_conv.MakeArgument(input_ref, + weight_ref, + output_ref, + strides_long, + dilations_long, + pads_long, + pads_long, + in_element_op, + wei_element_op, + OutElementOp{}, + a_tensors_ref, + b_tensors_ref, + {}); + ref_invoker.Run(ref_arg); + + // Copy result from device and compare + Tensor output_gpu(output_ref.mDesc); + output_dev.FromDevice(output_gpu.mData.data()); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Compare results + return ck::utils::check_err(output_gpu, output_ref); +} + } // namespace test } // namespace ck diff --git a/test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp b/test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp new file mode 100644 index 0000000000..ebe1e9695c --- /dev/null +++ b/test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp @@ -0,0 +1,319 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include "gpu_reference_utils.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +using namespace ck; +using ck::test::ConvKernelType; + +// ==================== D Tensor (Bias) Tests ==================== + +template +bool test_conv_gpu_ref_with_bias(const ck::utils::conv::ConvParam& params) +{ + using tensor_operation::element_wise::AddClamp; + + // Create tensor descriptors + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(params); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(params); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params); + + // Create tensors + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor output(out_g_n_k_wos_desc); + Tensor bias(out_g_n_k_wos_desc); // Same shape as output + + // Allocate device memory + DeviceMem input_dev(input.mData.size() * sizeof(InDataType)); + DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType)); + DeviceMem bias_dev(bias.mData.size() * sizeof(OutDataType)); + DeviceMem output_dev(output.mData.size() * sizeof(OutDataType)); + + // Initialize and copy tensors + test::initialize_and_copy_tensor(input, input_dev); + test::initialize_and_copy_tensor(weight, weight_dev); + test::initialize_and_copy_tensor(bias, bias_dev); + + // Test with AddClamp (bias operation with clamping) + AddClamp out_element_op(0.0f, 6.0f); // Clamp between 0 and 6 + + return test::test_conv_fwd_with_d_tensor_impl( + params, input, weight, bias, input_dev, weight_dev, bias_dev, output_dev, out_element_op); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16Bias) +{ + auto params = test::conv_test_shapes::get_2d_small(); + bool result = test_conv_gpu_ref_with_bias<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32Bias) +{ + auto params = test::conv_test_shapes::get_2d_medium(); + bool result = test_conv_gpu_ref_with_bias<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv3DFP32Bias) +{ + auto params = test::conv_test_shapes::get_3d_small(); + bool result = test_conv_gpu_ref_with_bias<3, + float, + float, + float, + tensor_layout::convolution::GNCDHW, + tensor_layout::convolution::GKCZYX, + tensor_layout::convolution::GNKDHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2Bias) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g2(); + bool result = test_conv_gpu_ref_with_bias<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32GroupedG4Bias) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g4(); + bool result = test_conv_gpu_ref_with_bias<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +// ==================== D Tensor (Bilinear) Tests ==================== + +template +bool test_conv_gpu_ref_with_bilinear(const ck::utils::conv::ConvParam& params) +{ + using tensor_operation::element_wise::Bilinear; + + // Create tensor descriptors + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(params); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(params); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params); + + // Create tensors + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor output(out_g_n_k_wos_desc); + Tensor d_tensor(out_g_n_k_wos_desc); // Same shape as output + + // Allocate device memory + DeviceMem input_dev(input.mData.size() * sizeof(InDataType)); + DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType)); + DeviceMem d_dev(d_tensor.mData.size() * sizeof(OutDataType)); + DeviceMem output_dev(output.mData.size() * sizeof(OutDataType)); + + // Initialize and copy tensors + test::initialize_and_copy_tensor(input, input_dev); + test::initialize_and_copy_tensor(weight, weight_dev); + test::initialize_and_copy_tensor(d_tensor, d_dev); + + // Test with Bilinear: y = alpha * conv_result + beta * d_tensor + Bilinear out_element_op(1.5f, 0.5f); // alpha=1.5, beta=0.5 + + return test::test_conv_fwd_with_d_tensor_impl( + params, input, weight, d_tensor, input_dev, weight_dev, d_dev, output_dev, out_element_op); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16Bilinear) +{ + auto params = test::conv_test_shapes::get_2d_small(); + bool result = test_conv_gpu_ref_with_bilinear<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32Bilinear) +{ + auto params = test::conv_test_shapes::get_2d_medium(); + bool result = test_conv_gpu_ref_with_bilinear<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2Bilinear) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g2(); + bool result = test_conv_gpu_ref_with_bilinear<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +// ==================== Multiple A/B (ScaleAdd) Tests ==================== + +template +bool test_conv_gpu_ref_with_scaleadd(const ck::utils::conv::ConvParam& params) +{ + using tensor_operation::element_wise::ScaleAdd; + + // Create tensor descriptors + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(params); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(params); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params); + + // Create tensors + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor output(out_g_n_k_wos_desc); + Tensor a_extra(in_g_n_c_wis_desc); // Extra A tensor (same shape as input) + Tensor b_extra(wei_g_k_c_xs_desc); // Extra B tensor (same shape as weight) + + // Allocate device memory + DeviceMem input_dev(input.mData.size() * sizeof(InDataType)); + DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType)); + DeviceMem a_extra_dev(a_extra.mData.size() * sizeof(InDataType)); + DeviceMem b_extra_dev(b_extra.mData.size() * sizeof(WeiDataType)); + DeviceMem output_dev(output.mData.size() * sizeof(OutDataType)); + + // Initialize and copy tensors + test::initialize_and_copy_tensor(input, input_dev); + test::initialize_and_copy_tensor(weight, weight_dev); + test::initialize_and_copy_tensor(a_extra, a_extra_dev); + test::initialize_and_copy_tensor(b_extra, b_extra_dev); + + // Test with ScaleAdd: in_out = scale * in_0 + in_1, wei_out = scale * wei_0 + wei_1 + ScaleAdd in_element_op(2.0f); // scale factor for input + ScaleAdd wei_element_op(1.5f); // scale factor for weight + + return test::test_conv_fwd_with_multi_ab_impl(params, + input, + weight, + a_extra, + b_extra, + input_dev, + weight_dev, + a_extra_dev, + b_extra_dev, + output_dev, + in_element_op, + wei_element_op); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16ScaleAdd) +{ + auto params = test::conv_test_shapes::get_2d_small(); + bool result = test_conv_gpu_ref_with_scaleadd<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32ScaleAdd) +{ + auto params = test::conv_test_shapes::get_2d_medium(); + bool result = test_conv_gpu_ref_with_scaleadd<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2ScaleAdd) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g2(); + bool result = test_conv_gpu_ref_with_scaleadd<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} diff --git a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp index b45f204b40..ea7289d6bf 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp @@ -21,7 +21,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" using ::ck::DeviceMem; using ::ck::HostTensorDescriptor; @@ -63,37 +63,62 @@ class TestGroupedConvndBwdData : public ::testing::Test Tensor& out, Tensor& d) { + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); - std::array, NumDs> d_tensors = {d}; - auto ref_conv = - ck::tensor_operation::host::ReferenceConvBwdData(); + // Prepare D tensor with correct strides for GPU kernel + std::vector d_lengths; + std::vector d_strides; + auto copy_dims = [](const auto& desc, auto& lengths, auto& strides) { + const auto& l = desc.GetLengths(); + const auto& s = desc.GetStrides(); + lengths.assign(l.begin(), l.end()); + strides.assign(s.begin(), s.end()); + }; + copy_dims(in_g_n_c_wis_desc, d_lengths, d_strides); - auto ref_invoker = ref_conv.MakeInvoker(); + std::array, NumDs> d_lengths_array = {d_lengths}; + std::array, NumDs> d_strides_array = {d_strides}; - auto ref_argument = ref_conv.MakeArgument(in_host, - wei, - out, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - Bilinear{alpha, beta}, - WeiElementOp{}, - OutElementOp{}, - {}, - {}, - d_tensors); + DeviceMem d_device_buf(sizeof(InDataType) * d.mDesc.GetElementSpaceSize()); + d_device_buf.ToDevice(d.mData.data()); - ref_invoker.Run(ref_argument); + std::array p_ds = { + static_cast(d_device_buf.GetDeviceBuffer())}; + + DeviceMem in_device_buf(sizeof(InDataType) * in_host.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + + wei_device_buf.ToDevice(wei.mData.data()); + out_device_buf.ToDevice(out.mData.data()); + + ck::ref::naive_conv_bwd_data_multi_abd<0, + 0, + NumDs, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + InDataType>( + static_cast(in_device_buf.GetDeviceBuffer()), + {static_cast(wei_device_buf.GetDeviceBuffer())}, + {static_cast(out_device_buf.GetDeviceBuffer())}, + p_ds, + conv_param, + d_lengths_array, + d_strides_array, + InElementOp{alpha, beta}, + WeiElementOp{}, + OutElementOp{}); + + in_device_buf.FromDevice(in_host.mData.data()); } bool PerformConvDataBilinear(ck::utils::conv::ConvParam& conv_param, diff --git a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp index 84d013bca7..f1f985883c 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp @@ -21,7 +21,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" using ::ck::DeviceMem; using ::ck::HostTensorDescriptor; @@ -55,38 +55,24 @@ class TestGroupedConvndBwdData : public ::testing::Test void RunReference(ck::utils::conv::ConvParam& conv_param, Tensor& in_host, - Tensor& wei, - Tensor& out) + DeviceMem& wei_device_buf, + DeviceMem& out_device_buf) { - auto ref_conv = - ck::tensor_operation::host::ReferenceConvBwdData /*Num D Elementwise - Tensors*/ - {}; + // GPU reference + DeviceMem gpu_ref_in_dev(sizeof(InDataType) * in_host.mDesc.GetElementSpaceSize()); + gpu_ref_in_dev.SetZero(); // bwd data needs zero initialization - auto ref_invoker = ref_conv.MakeInvoker(); + ck::ref::naive_conv_bwd_data( + static_cast(gpu_ref_in_dev.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + InElementOp{alpha}, + WeiElementOp{}, + OutElementOp{}); - auto ref_argument = ref_conv.MakeArgument(in_host, - wei, - out, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - InElementOp{alpha}, - WeiElementOp{}, - OutElementOp{}); - - ref_invoker.Run(ref_argument); + ck::hip_check_error(hipDeviceSynchronize()); + gpu_ref_in_dev.FromDevice(in_host.mData.data()); } bool PerformConvDataScale(ck::utils::conv::ConvParam& conv_param, const ck::index_t split_k) @@ -121,10 +107,11 @@ class TestGroupedConvndBwdData : public ::testing::Test DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); - in_device_buf.ToDevice(in_device.mData.data()); out_device_buf.ToDevice(out.mData.data()); wei_device_buf.ToDevice(wei.mData.data()); + RunReference(conv_param, in_host, wei_device_buf, out_device_buf); + std::array out_lengths{}; std::array out_strides{}; std::array wei_lengths{}; @@ -149,8 +136,6 @@ class TestGroupedConvndBwdData : public ::testing::Test copy(conv_param.input_left_pads_, input_left_pads); copy(conv_param.input_right_pads_, input_right_pads); - RunReference(conv_param, in_host, wei, out); - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD& out, Tensor& d) { - std::array, NumDs> d_tensors = {d}; - auto ref_conv = - ck::tensor_operation::host::ReferenceConvBwdWeight{}; + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); - auto ref_invoker = ref_conv.MakeInvoker(); - auto ref_argument = ref_conv.MakeArgument(in, - wei_host, - out, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - InElementOp{}, - WeiElementOp{alpha, beta}, - OutElementOp{}, - {}, - {}, - d_tensors); + // Prepare D tensor with correct strides for GPU kernel + std::vector d_lengths; + std::vector d_strides; + auto copy_dims = [](const auto& desc, auto& lengths, auto& strides) { + const auto& l = desc.GetLengths(); + const auto& s = desc.GetStrides(); + lengths.assign(l.begin(), l.end()); + strides.assign(s.begin(), s.end()); + }; + copy_dims(wei_g_k_c_xs_desc, d_lengths, d_strides); - ref_invoker.Run(ref_argument); + std::array, NumDs> d_lengths_array = {d_lengths}; + std::array, NumDs> d_strides_array = {d_strides}; + + DeviceMem d_device_buf(sizeof(WeiDataType) * d.mDesc.GetElementSpaceSize()); + d_device_buf.ToDevice(d.mData.data()); + + std::array p_ds = { + static_cast(d_device_buf.GetDeviceBuffer())}; + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_host.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + out_device_buf.ToDevice(out.mData.data()); + + ck::ref::naive_conv_bwd_weight_multi_abd<0, + 0, + NumDs, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + WeiDataType>( + {static_cast(in_device_buf.GetDeviceBuffer())}, + static_cast(wei_device_buf.GetDeviceBuffer()), + {static_cast(out_device_buf.GetDeviceBuffer())}, + p_ds, + conv_param, + d_lengths_array, + d_strides_array, + InElementOp{}, + WeiElementOp{alpha, beta}, + OutElementOp{}); + + wei_device_buf.FromDevice(wei_host.mData.data()); } bool PerformConvWeightBilinear(ck::utils::conv::ConvParam& conv_param, diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp index bce6da4b68..5aa0b13c07 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp @@ -184,5 +184,5 @@ TYPED_TEST(TestGroupedConvndBwdWeightDefault, SingleStageAutoDeduce) this->conv_param = {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; this->split_k_ = -1; bool is_supported = this->template Run<2>(); - EXPECT_FALSE(is_supported); + EXPECT_TRUE(is_supported); } diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index 6f8b71679c..725c5716d9 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -21,13 +21,12 @@ endif() if(GPU_TARGETS MATCHES "gfx9") if(CK_EXPERIMENTAL_BUILDER) - # TODO: Reenable after the instance fixes - # add_executable(test_grouped_convnd_fwd_tile test_grouped_convnd_fwd_tile.cpp) - # target_compile_options(test_grouped_convnd_fwd_tile PRIVATE -Wno-global-constructors -Wno-undef -Wno-c++20-compat) - # target_link_libraries(test_grouped_convnd_fwd_tile PRIVATE gtest_main getopt::getopt utility) - # if(TARGET device_grouped_conv_fwd_tile_instances) - # target_link_libraries(test_grouped_convnd_fwd_tile PRIVATE device_grouped_conv_fwd_tile_instances) - # endif() + add_gtest_executable(test_grouped_convnd_fwd_tile test_grouped_convnd_fwd_tile.cpp) + target_compile_options(test_grouped_convnd_fwd_tile PRIVATE -Wno-global-constructors -Wno-undef -Wno-c++20-compat) + target_link_libraries(test_grouped_convnd_fwd_tile PRIVATE gtest_main getopt::getopt utility) + if(TARGET device_grouped_conv_fwd_tile_instances) + target_link_libraries(test_grouped_convnd_fwd_tile PRIVATE device_grouped_conv_fwd_tile_instances) + endif() endif() endif() diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp index 1b37f5eb4e..645aab0151 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp @@ -66,10 +66,10 @@ class TestGroupedConvndFwdBilinear : public ::testing::Test OutDataType, AComputeType, BComputeType, - IndexType>(true, // do_verification + IndexType>(2, // do_verification 1, // init_method: integer value false, // do_log - true, // time_kernel + false, // time_kernel param, bilinear_op); } diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp index 199a50f0fd..e78e61f707 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp @@ -24,6 +24,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" using I8 = int8_t; using F16 = ck::half_t; @@ -131,39 +132,34 @@ bool profile_grouped_conv_fwd_scaleadd_ab_impl(int do_verification, wei_device_buf.ToDevice(weight.mData.data()); wei_bias_device_buf.ToDevice(weight_bias.mData.data()); - // Run reference op + // Run GPU reference if(do_verification) { - const std::array, NumAs - 1> elementwise_a_tensors = {input_bias}; - const std::array, NumBs - 1> elementwise_b_tensors = {weight_bias}; - auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + std::array in_ptrs = { + reinterpret_cast(in_device_buf.GetDeviceBuffer()), + reinterpret_cast(in_bias_device_buf.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(wei_device_buf.GetDeviceBuffer()), + reinterpret_cast(wei_bias_device_buf.GetDeviceBuffer())}; + std::array d_ptrs = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; - auto ref_invoker = ref_conv.MakeInvoker(); - auto ref_argument = ref_conv.MakeArgument(input, - weight, - host_output, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - in_element_op, - wei_element_op, - out_element_op, - elementwise_a_tensors, - elementwise_b_tensors); + ck::ref::naive_conv_fwd_multi_abd<1, 1, 0, InLayout, WeiLayout, OutLayout>( + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op); - // init host output to zero - host_output.SetZero(); + HIP_CHECK_ERROR(hipDeviceSynchronize()); - ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(host_output.mData.data()); } std::string best_op_name; diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp index c04a15ec98..fe517572ff 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp @@ -7,12 +7,14 @@ #include #include -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" #include "ck_tile/host/device_prop.hpp" #include "profiler/grouped_convolution_forward_tile_algs.hpp" // TODO: Remove limitation of conv fwd gpu reference which does not support right pad #define CK_CONV_FWD_REF_SKIP_RIGHT_PAD_CASES 1 +// TODO: Remove this limitation after gpu reference fix +#define ENABLE_BHALF_GROUPED_CONV_FWD_TESTS 0 static ck::index_t args_mask = 0xffff; static ck::index_t instance_index = -1; @@ -67,7 +69,10 @@ class TestGroupedConvndFwdTile : public ::testing::Test auto inputs = alloc_inputs(args); auto outputs = alloc_outputs(args); - ckt::init_inputs(args, inputs.get()); + ckt::init_tensor_buffer_uniform_fp( + inputs.get().input, args.make_input_descriptor(), -5, 5); + ckt::init_tensor_buffer_uniform_fp( + inputs.get().weight, args.make_weight_descriptor(), -5, 5); std::cout << args.make_input_descriptor() << std::endl; std::cout << args.make_weight_descriptor() << std::endl; @@ -150,13 +155,12 @@ using KernelTypes2d = ::testing::Types, - SignatureDetails<2, - ckb::DataType::BF16, - ckb::DataType::FP32, - ckb::TensorLayout::NHWGC, - ckb::TensorLayout::GKYXC, ckb::TensorLayout::NHWGK>>; +#if ENABLE_BHALF_GROUPED_CONV_FWD_TESTS +SignatureDetails < 2, ckb::DataType::BF16, ckb::DataType::FP32, ckb::TensorLayout::NHWGC, + ckb::TensorLayout::GKYXC, ckb::TensorLayout::NHWGK >> + ; +#endif using KernelTypes3d = ::testing::Types, - SignatureDetails<3, - ckb::DataType::BF16, - ckb::DataType::FP32, - ckb::TensorLayout::NDHWGC, - ckb::TensorLayout::GKZYXC, ckb::TensorLayout::NDHWGK>>; +#if ENABLE_BHALF_GROUPED_CONV_FWD_TESTS +SignatureDetails < 3, ckb::DataType::BF16, ckb::DataType::FP32, ckb::TensorLayout::NDHWGC, + ckb::TensorLayout::GKZYXC, ckb::TensorLayout::NDHWGK >> + ; +#endif template class TestGroupedConvndFwdTile2d : public TestGroupedConvndFwdTile diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp index d1706d4cec..68a8b016e3 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp @@ -49,7 +49,7 @@ class TestGroupedConvndFwd : public ::testing::Test DataType, IndexType, false /*BiasGK*/>( - true, // do_verification + 2, // do_verification 1, // init_method: integer value false, // do_log false, // time_kernel diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp index fef485a950..2c04b52b4f 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp @@ -50,7 +50,7 @@ class TestGroupedConvndFwd : public ::testing::Test DataType, IndexType, Clamp>( - true, // do_verification + 2, // do_verification: 2 = GPU reference 1, // init_method: integer value false, // do_log false, // time_kernel diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp index a78a17cbf4..78cfe126a3 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp @@ -44,7 +44,7 @@ class TestGroupedConvndFwd : public ::testing::Test DataType, IndexType, true /*BiasGK*/>( - true, // do_verification + 2, // do_verification 1, // init_method: integer value false, // do_log false, // time_kernel diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp index b4179cae62..b2a9cff231 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp @@ -58,10 +58,10 @@ class TestGroupedConvndFwdScale : public ::testing::Test OutDataType, ck::tensor_operation::element_wise::Scale, InDataType, - InDataType>(true, // do_verification + InDataType>(2, // do_verification: 2 = GPU reference 1, // init_method: integer value false, // do_log - true, // time_kernel + false, // time_kernel param); } EXPECT_TRUE(pass); diff --git a/test/util/unit_sequence.cpp b/test/util/unit_sequence.cpp index f09fd86e06..9e62b9a6c0 100644 --- a/test/util/unit_sequence.cpp +++ b/test/util/unit_sequence.cpp @@ -229,6 +229,32 @@ TEST(SequenceGen, UniformSequenceZeroSize) EXPECT_TRUE((is_same::value)); } +TEST(SequenceGen, UniformSequenceSingleElement) +{ + using Result = typename uniform_sequence_gen<1, 99>::type; + using Expected = Sequence<99>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, UniformSequenceDifferentValues) +{ + using Result1 = typename uniform_sequence_gen<3, 0>::type; + using Expected1 = Sequence<0, 0, 0>; + EXPECT_TRUE((is_same::value)); + + using Result2 = typename uniform_sequence_gen<4, -5>::type; + using Expected2 = Sequence<-5, -5, -5, -5>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, UniformSequenceLargeSize) +{ + // Test with larger size to verify __make_integer_seq implementation + using Result = typename uniform_sequence_gen<16, 7>::type; + using Expected = Sequence<7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7>; + EXPECT_TRUE((is_same::value)); +} + // Test make_index_sequence TEST(SequenceGen, MakeIndexSequence) { @@ -244,6 +270,54 @@ TEST(SequenceGen, MakeIndexSequenceZero) EXPECT_TRUE((is_same::value)); } +// Test sequence_gen with custom functors +TEST(SequenceGen, SequenceGenWithDoubleFunctor) +{ + struct DoubleFunctor + { + __host__ __device__ constexpr index_t operator()(index_t i) const { return i * 2; } + }; + using Result = typename sequence_gen<5, DoubleFunctor>::type; + using Expected = Sequence<0, 2, 4, 6, 8>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, SequenceGenWithSquareFunctor) +{ + struct SquareFunctor + { + __host__ __device__ constexpr index_t operator()(index_t i) const { return i * i; } + }; + using Result = typename sequence_gen<5, SquareFunctor>::type; + using Expected = Sequence<0, 1, 4, 9, 16>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, SequenceGenZeroSize) +{ + struct IdentityFunctor + { + __host__ __device__ constexpr index_t operator()(index_t i) const { return i; } + }; + using Result = typename sequence_gen<0, IdentityFunctor>::type; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); + // Also verify non-zero size works with identity + using Result5 = typename sequence_gen<5, IdentityFunctor>::type; + EXPECT_TRUE((is_same>::value)); +} + +TEST(SequenceGen, SequenceGenSingleElement) +{ + struct ConstantFunctor + { + __host__ __device__ constexpr index_t operator()(index_t) const { return 42; } + }; + using Result = typename sequence_gen<1, ConstantFunctor>::type; + using Expected = Sequence<42>; + EXPECT_TRUE((is_same::value)); +} + // Test sequence_merge TEST(SequenceMerge, MergeTwoSequences) { @@ -272,6 +346,66 @@ TEST(SequenceMerge, MergeSingleSequence) EXPECT_TRUE((is_same::value)); } +TEST(SequenceMerge, MergeFourSequences) +{ + // Test the 4-sequence specialization + using Seq1 = Sequence<1>; + using Seq2 = Sequence<2, 3>; + using Seq3 = Sequence<4, 5, 6>; + using Seq4 = Sequence<7, 8>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3, 4, 5, 6, 7, 8>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeFiveSequences) +{ + // Test the binary tree reduction path (5+ sequences) + using Seq1 = Sequence<1>; + using Seq2 = Sequence<2>; + using Seq3 = Sequence<3>; + using Seq4 = Sequence<4>; + using Seq5 = Sequence<5>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3, 4, 5>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeManySequences) +{ + // Test with many sequences to stress the binary tree reduction + using Seq1 = Sequence<1>; + using Seq2 = Sequence<2>; + using Seq3 = Sequence<3, 4>; + using Seq4 = Sequence<5>; + using Seq5 = Sequence<6, 7>; + using Seq6 = Sequence<8>; + using Seq7 = Sequence<9, 10>; + using Seq8 = Sequence<11, 12>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeEmptySequences) +{ + // Test merging empty sequences + using Seq1 = Sequence<>; + using Seq2 = Sequence<1, 2>; + using Seq3 = Sequence<>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeZeroSequences) +{ + // Test the empty specialization + using Result = typename sequence_merge<>::type; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + // Test sequence_split TEST(SequenceSplit, SplitInMiddle) { diff --git a/tile_engine/ops/gemm/gemm_validation_utils.py b/tile_engine/ops/gemm/gemm_validation_utils.py index cae6123307..1af45f8e90 100644 --- a/tile_engine/ops/gemm/gemm_validation_utils.py +++ b/tile_engine/ops/gemm/gemm_validation_utils.py @@ -128,7 +128,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ @@ -136,7 +135,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], @@ -148,7 +146,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ @@ -156,7 +153,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], @@ -169,7 +165,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ @@ -177,7 +172,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [