Merge remote-tracking branch 'origin/jograner/bwd-weight-splitk-autodeduce' into features/grouped-conv-perf-uplift

This commit is contained in:
Ville Pietilä
2026-01-28 10:57:40 -05:00
144 changed files with 9140 additions and 2152 deletions

View File

@@ -25,6 +25,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
## Composable Kernel 1.2.0 for ROCm 7.2.0
### Added
* Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4
* Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with C++ and Python frontends starting with GEMM support.
* Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle.
* Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM.

View File

@@ -259,6 +259,11 @@ if ((SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx
add_definitions(-DCK_USE_GFX94)
set(CK_USE_GFX94 "ON")
endif()
# gfx950 targets: unless FORCE_DISABLE_XDL is set, expose CK_USE_GFX950 both as a
# preprocessor definition and as a CMake variable.
if (SUPPORTED_GPU_TARGETS MATCHES "gfx950" AND NOT FORCE_DISABLE_XDL)
message(STATUS "Enabling XDL FP8 gemms on gfx950")
# NOTE(review): add_definitions() is directory-scoped, so the define reaches every
# target created in this directory and below.
add_definitions(-DCK_USE_GFX950)
set(CK_USE_GFX950 "ON")
endif()
# new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA
set(CK_TILE_USE_WMMA 0)

101
Dockerfile.manylinux Normal file
View File

@@ -0,0 +1,101 @@
# Build image based on the TheRock manylinux x86_64 builder.
FROM ghcr.io/rocm/therock_build_manylinux_x86_64:latest
# Build-time knobs. compiler_version / compiler_commit select an optional LLVM
# source build further down; CK_SCCACHE enables the optional sccache download.
ARG DEBIAN_FRONTEND=noninteractive
ARG ROCMVERSION=7.2
ARG compiler_version=""
ARG compiler_commit=""
ARG CK_SCCACHE=""
# NOTE(review): DEB_ROCM_REPO and APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE are apt-oriented
# settings; nothing in this file references them while packages are installed via dnf.
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn
ENV DEBIAN_FRONTEND=noninteractive
USER root
# Add rocm repository
RUN dnf clean all && dnf update -y && dnf -v install wget gnupg2 curl -y
# Install the amdgpu-install package and then rocm-dev.
# NOTE(review): the 7.2 / 7.2.70200 version in the URL is hard-coded and is not
# derived from the ROCMVERSION build argument.
RUN wget https://repo.radeon.com/amdgpu-install/7.2/rhel/8.10/amdgpu-install-7.2.70200-1.el8.noarch.rpm && \
dnf install ./amdgpu-install-7.2.70200-1.el8.noarch.rpm -y && \
dnf update -y && \
dnf install python3-setuptools python3-wheel -y && \
dnf install rocm-dev -y
## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined
ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache
# The install directory is appended to PATH unconditionally, even when sccache is not downloaded.
ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin
ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION}
# Persist the build argument into the image environment.
ENV CK_SCCACHE=$CK_SCCACHE
# Download the prebuilt sccache binary and mark it executable; skipped when CK_SCCACHE is empty.
# NOTE(review): the download uses plain http and performs no checksum verification.
RUN if [ "$CK_SCCACHE" != "" ]; then \
mkdir -p ${SCCACHE_INSTALL_LOCATION} && \
curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \
chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache; \
fi
# Install dependencies
# Single layer that, in order:
#   1. installs the dnf packages listed below and cleans the dnf cache,
#   2. removes the downloaded amdgpu-install package file,
#   3. builds and installs ccache from its default branch,
#   4. builds ClangBuildAnalyzer (cloned while the working directory is still
#      ccache/build, because there is no `cd /` after the ccache step),
#   5. builds cppcheck (no install step follows `cmake --build .`),
#   6. installs the Python packages used for processing performance results,
#   7. creates the `render` group if it does not already exist,
#   8. builds and installs rocm-cmake from its master branch.
# NOTE(review): `rm -rf /var/lib/apt/lists/*` targets an apt cache path; on this
# dnf-based image it is presumably a no-op - confirm.
# NOTE(review): the git clones are unpinned, so the image contents depend on build time.
RUN dnf update -y && DEBIAN_FRONTEND=noninteractive dnf install -y \
cmake \
clang-tools-extra \
gcc-c++ \
libstdc++ \
libstdc++-devel \
libstdc++-static \
git \
hip-rocclr \
jq \
mpich \
net-tools \
pkg-config \
redis \
sshpass \
stunnel \
vim \
nano \
zip \
openssh-server \
kmod && \
dnf clean all && \
rm -rf /var/lib/apt/lists/* && \
rm -rf amdgpu-install* && \
#Install latest ccache
git clone https://github.com/ccache/ccache.git && \
cd ccache && mkdir build && cd build && cmake .. && make install && \
#Install ClangBuildAnalyzer
git clone https://github.com/aras-p/ClangBuildAnalyzer.git && \
cd ClangBuildAnalyzer/ && \
make -f projects/make/Makefile && \
cd / && \
#Install latest cppcheck
git clone https://github.com/danmar/cppcheck.git && \
cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . && \
cd / && \
# Install packages for processing the performance results
pip3 install --break-system-packages --upgrade pytest pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \
# Add render group
groupadd -f render && \
# Install the new rocm-cmake version
git clone -b master https://github.com/ROCm/rocm-cmake.git && \
cd rocm-cmake && mkdir build && cd build && \
cmake .. && cmake --build . && cmake --build . --target install
WORKDIR /
# Add alternative compilers, if necessary
# Persist the compiler selection into the image environment and echo it to the build log.
ENV compiler_version=$compiler_version
ENV compiler_commit=$compiler_commit
RUN sh -c "echo compiler version = '$compiler_version'" && \
sh -c "echo compiler commit = '$compiler_commit'"
# Case 1: amd-staging / amd-mainline with no commit given -> build the branch head.
# NOTE(review): CMAKE_INSTALL_PREFIX is set to /opt/rocm/llvm, but only `make` is run;
# no install step is visible in this file.
RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
cd llvm-project && mkdir build && cd build && \
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
make -j 8 ; \
else echo "using the release compiler"; \
fi
# Case 2: same branches, but with an explicit commit that is checked out before building.
# Each of the two steps prints "using the release compiler" whenever its own condition is
# false, so the message also appears once when the other step performs a source build.
RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
make -j 8 ; \
else echo "using the release compiler"; \
fi

4
Jenkinsfile vendored
View File

@@ -39,10 +39,10 @@ def sendFailureNotifications() {
// Error patterns to scan build logs for specific failure types and send detailed notifications.
def failurePatterns = [
[pattern: /login attempt to .* failed with status: 401 Unauthorized/, description: "Docker registry authentication failed"],
[pattern: /docker login failed/, description: "Docker login failed"],
[pattern: /(.*)docker login failed(.*)/, description: "Docker login failed"],
[pattern: /HTTP request sent .* 404 Not Found/, description: "HTTP request failed with 404"],
[pattern: /cat: .* No such file or directory/, description: "GPU not found"],
[pattern: /GPU not found/, description: "GPU not found"],
[pattern: /(.*)GPU not found(.*)/, description: "GPU not found"],
[pattern: /Could not connect to Redis at .* Connection timed out/, description: "Redis connection timed out"]
]

View File

@@ -19,22 +19,22 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
PassThrough, PassThrough, PassThrough, GemmSpec,
256,
128, 256, 64,
8, 8,
16, 16,
2, 8,
S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 8, 8, 1,
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, 1,
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, 1,
1, 8, 8, 1,
1, 1,
S<1, 64, 1, 4>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigQuantDecode<T>;
using GemmConfig = GemmConfigQuantDecodeInterwave<T>;
// GemmConfigQuantPrefill is also supported for aquant grouped quantization
// template <typename T>

View File

@@ -93,6 +93,27 @@ struct GemmConfigQuantDecode : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
// static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
/// Quant GEMM configuration with a 16 x 64 x (256 / sizeof(PrecType)) block tile,
/// a 1 x 4 x 1 warp layout, and the Interwave pipeline scheduler.
///
/// @tparam PrecType Element type; it determines K_Tile and K_Warp_Tile.
template <typename PrecType>
struct GemmConfigQuantDecodeInterwave : public GemmConfigBase
{
// Block tile extents along M, N and K. K_Tile is the number of PrecType
// elements that fit in 256 bytes.
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
// Number of warps along M, N and K.
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// Per-warp tile extents; the K extent is derived from the element type and M_Warp_Tile.
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
// Explicitly selects the Interwave scheduler for this configuration.
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
template <typename PrecType>
@@ -229,6 +250,8 @@ struct GemmConfigQuantPrefill : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
// static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
template <typename PrecType>

View File

@@ -650,7 +650,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
else
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(*aq_tensor_ptr);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(1.0f)}(*aq_tensor_ptr);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
@@ -659,6 +659,184 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
}
}
}
else if(init_method == 3)
{
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(*aq_tensor_ptr);
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
}
else
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(2.0f)}(*aq_tensor_ptr);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
{
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
}
}
}
else if(init_method == 4)
{
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
{
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else
{
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
}
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
}
ck_tile::FillUniformDistribution<AQDataType>{2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
}
else if(init_method == 5)
{
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
{
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else
{
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
}
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{1.0f, 1.0f, fill_seed(gen)}(a_m_k);
}
// Fill aquant such that column j has value j + 1 (1, 2, 3, 4, ...)
for(ck_tile::index_t row = 0;
row < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(0));
++row)
{
for(ck_tile::index_t col = 0;
col < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(1));
++col)
{
(*aq_tensor_ptr)(row, col) = static_cast<AQDataType>(col + 1);
}
}
// std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl;
ck_tile::FillUniformDistribution<BDataType>{1.0f, 1.0f, fill_seed(gen)}(b_k_n);
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*aq_tensor_ptr);
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
*bq_tensor_ptr);
}
}
else
{
a_m_k.SetZero();

View File

@@ -58,6 +58,7 @@ consteval BlockGemmSpec SetBlockGemm()
case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break;
case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break;
case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break;
case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile.";
case PipelineVersion::WEIGHT_ONLY:
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM.";
default: throw "Unknown PipelineVersion";
@@ -92,6 +93,7 @@ consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K.";
case PipelineVersion::V4: return ck_pipeline::v4;
case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM.";
case PipelineVersion::V6: throw "PipelineVersion::V6 can be used only for CK TILE.";
case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only;
default: throw "Unknown GridwiseGemmPipelineVersion";
}
@@ -137,6 +139,7 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
case PipelineVersion::V3: return ck_pipeline::v3;
case PipelineVersion::V4: return ck_pipeline::v4;
case PipelineVersion::V5: return ck_pipeline::v5;
case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile.";
case PipelineVersion::WEIGHT_ONLY:
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
default: throw "Unknown block GEMM PipelineVersion";

View File

@@ -91,6 +91,13 @@ struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V5>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
};
/// Maps the COMPUTE_V6 pipeline enumerator to `ck_tile::GemmPipelineAgBgCrCompV6`.
template <>
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V6>
{
/// Pipeline implementation instantiated for the given pipeline problem type.
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
{
@@ -103,6 +110,7 @@ consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
case PipelineVersion::V3: return ck_tile_pipeline::COMPUTE_V3;
case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4;
case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5;
case PipelineVersion::V6: return ck_tile_pipeline::COMPUTE_V6;
case PipelineVersion::WEIGHT_ONLY:
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
default: throw "Unknown block GEMM PipelineVersion";

View File

@@ -7,26 +7,25 @@
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/testing/testing.hpp"
#include "ck_tile/builder/testing/testing_reflect.hpp"
#include "ck_tile/builder/testing/filter_extent.hpp"
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/builder/testing/tensor_initialization.hpp"
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "ck_tile/builder/testing/validation.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
/// This file implements common functionality for invoking/testing grouped
/// forward convolutions created through the CK Builder API. The main item
/// of it is the ConvArgs structure - which contains a complete description
/// of it is the Args structure - which contains a complete description
/// of a convolution operation.
///
/// It is not intended that this file contains implementation details for
/// actually launching a convolution operation. As this can be done
/// through different APIs depending on the kernel (CK, CK Tile, or a
/// reference implementation), the code dealing with that is split out
/// into a separate header for each implementation.
/// into a separate header for each implementation. Nor does this file
/// deal with details for defining the data types (`Inputs` and `Outputs`)
/// for different conv directions, that is also split out into separate
/// headers to keep this one small.
namespace ck_tile::builder::test {
@@ -56,7 +55,7 @@ struct ConvTensorLengths
///
/// @see Args
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
requires ValidConvSignature<SIGNATURE>
struct Args<SIGNATURE>
{
constexpr static auto SPATIAL_DIM = SIGNATURE.spatial_dim;
@@ -204,53 +203,4 @@ struct Args<SIGNATURE>
}
};
/// @brief `Inputs` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see Inputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
struct Inputs<SIGNATURE>
{
void* input;
void* weight;
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
{
inspect("input", args.make_input_descriptor(), &Inputs<SIGNATURE>::input);
inspect("weight", args.make_weight_descriptor(), &Inputs<SIGNATURE>::weight);
}
};
/// @brief `Outputs` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see Outputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
struct Outputs<SIGNATURE>
{
void* output;
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
{
inspect("output", args.make_output_descriptor(), &Outputs<SIGNATURE>::output);
}
};
/// @brief `init_inputs()` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see alloc_inputs()
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
{
init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f);
init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f);
}
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,71 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/testing/tensor_initialization.hpp"
#include "ck_tile/builder/testing/testing_reflect.hpp"
#include "ck_tile/builder/testing/conv/args.hpp"
#include "ck_tile/builder/testing/conv/fwd.hpp"
#include "ck_tile/builder/testing/error.hpp"
/// This file deals with the backward weight-specific details of running grouped
/// convolution backwards weight operations. It mainly defines the data
/// structures (`Inputs` and `Outputs`), initialization, and validation. Note
/// that for this operation specifically, many of the operations are
/// implemented automatically via testing_reflect.hpp.
namespace ck_tile::builder::test {
/// @brief `Inputs` specialization for backwards weight convolution.
///
/// @tparam SIGNATURE Backwards weight convolution signature.
///
/// @see Inputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardWeight<SIGNATURE>
struct Inputs<SIGNATURE>
{
// Type-erased buffer laid out according to args.make_input_descriptor().
void* input;
// Type-erased buffer laid out according to args.make_output_descriptor().
// For backward weight, the convolution's output tensor is an input of the operation.
void* output;
// See testing_reflect.hpp
// Reports each field to `inspect` as (name, tensor descriptor, pointer-to-member).
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
{
inspect("input", args.make_input_descriptor(), &Inputs<SIGNATURE>::input);
inspect("output", args.make_output_descriptor(), &Inputs<SIGNATURE>::output);
}
};
/// @brief `Outputs` specialization for backwards weight convolution.
///
/// @tparam SIGNATURE Backwards weight convolution signature.
///
/// @see Outputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardWeight<SIGNATURE>
struct Outputs<SIGNATURE>
{
// Type-erased buffer laid out according to args.make_weight_descriptor().
// For backward weight, the weight tensor is the result of the operation.
void* weight;
// See testing_reflect.hpp
// Reports the single field to `inspect` as (name, tensor descriptor, pointer-to-member).
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
{
inspect("weight", args.make_weight_descriptor(), &Outputs<SIGNATURE>::weight);
}
};
/// @brief `init_inputs()` specialization for backwards weight convolution.
///
/// @tparam SIGNATURE Backwards weight convolution signature.
///
/// @see init_inputs()
template <auto SIGNATURE>
    requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardWeight<SIGNATURE>
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
{
    // Both operands of the backward weight pass are drawn from the same
    // uniform floating-point range.
    constexpr float range_min = -2.0f;
    constexpr float range_max = 2.0f;

    // Input tensor first, then the output tensor (which is an input here).
    init_tensor_buffer_uniform_fp(
        inputs.input, args.make_input_descriptor(), range_min, range_max);
    init_tensor_buffer_uniform_fp(
        inputs.output, args.make_output_descriptor(), range_min, range_max);
}
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,276 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/testing/testing.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include <type_traits>
#include <array>
/// This file contains the implementation details for invoking/testing
/// grouped backward weight convolution operations in old CK. The main item
/// is the `run()` function, which is the main implementation used to invoke
/// CK grouped backward weight convolution kernels.
namespace ck_tile::builder::test {
namespace detail {
/// @brief Concept for checking whether a bwd weight convolution is invoked like old CK.
///
/// This is the same as `::ck_tile::builder::test::CkConvBwdWeightInstance`, except
/// with some utility aliases. For that reason, it is moved to this detail
/// namespace.
template <typename Conv,
auto SIGNATURE,
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
// TODO: We shouldn't need to call into an internal namespace here.
typename Types = factory::internal::ConvTensorDataTypes<SIGNATURE>,
typename Ops = factory::internal::ConvElementwiseOps<SIGNATURE>>
// The requires-parameters below are never evaluated; they only provide values of the
// right types so the MakeArgument call expression can be checked for well-formedness.
// p_b is the only pointer to non-const data (the weight tensor).
concept CkConvBwdWeightInstance = requires(Conv& conv,
const Types::InDataType* p_a,
Types::WeiDataType* p_b,
const Types::OutDataType* p_e,
std::array<index_t, SPATIAL_DIM + 3> lengths,
std::array<index_t, SPATIAL_DIM + 3> strides,
std::array<index_t, SPATIAL_DIM> filter,
Ops::InElementwiseOp elementwise_a,
Ops::WeiElementwiseOp elementwise_b,
Ops::OutElementwiseOp elementwise_cde,
ck::index_t split_k) {
requires ValidConvSignature<SIGNATURE>;
requires ConvDirectionIsBackwardWeight<SIGNATURE>;
{
// The same `lengths` / `strides` / `filter` placeholders are reused for every
// tensor and every extent argument because only their types matter here.
conv.MakeArgument(p_a,
p_b,
p_e,
// A lengths/strides
lengths,
strides,
// B lengths/strides
lengths,
strides,
// E lengths/strides
lengths,
strides,
// strides/dilations/pads
filter,
filter,
filter,
filter,
// element-wise operations.
elementwise_a,
elementwise_b,
elementwise_cde,
split_k)
};
};
/// @brief Concept for checking whether a bwd weight convolution is multiple-d and
/// invoked like old CK.
///
/// This is the same as `::ck_tile::builder::test::CkConvBwdWeightMultipleDInstance`, except
/// with some utility aliases. For that reason, it is moved to this detail
/// namespace.
template <typename Conv,
auto SIGNATURE,
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
// TODO: We shouldn't need to call into an internal namespace here.
typename Types = factory::internal::ConvTensorDataTypes<SIGNATURE>,
typename Ops = factory::internal::ConvElementwiseOps<SIGNATURE>>
// The requires-parameters below are never evaluated; they only provide values of the
// right types so the MakeArgument call expression can be checked for well-formedness.
// Compared with the non-multiple-D concept, this call shape has three extra `{}`
// arguments in the positions marked with TODO comments below.
concept CkConvBwdWeightMultipleDInstance = requires(Conv& conv,
const Types::InDataType* p_a,
Types::WeiDataType* p_b,
const Types::OutDataType* p_e,
std::array<index_t, SPATIAL_DIM + 3> lengths,
std::array<index_t, SPATIAL_DIM + 3> strides,
std::array<index_t, SPATIAL_DIM> filter,
Ops::InElementwiseOp elementwise_a,
Ops::WeiElementwiseOp elementwise_b,
Ops::OutElementwiseOp elementwise_cde,
ck::index_t split_k) {
requires ValidConvSignature<SIGNATURE>;
requires ConvDirectionIsBackwardWeight<SIGNATURE>;
{
// The same `lengths` / `strides` / `filter` placeholders are reused for every
// tensor and every extent argument because only their types matter here.
conv.MakeArgument(p_a,
p_b,
p_e,
// TODO: Actually support multiple d
{},
// A lengths/strides
lengths,
strides,
// B lengths/strides
lengths,
strides,
// E lengths/strides
lengths,
strides,
// TODO: Multiple D lengths/strides
{},
{},
// strides/dilations/pads
filter,
filter,
filter,
filter,
// element-wise operations.
elementwise_a,
elementwise_b,
elementwise_cde,
split_k)
};
};
} // namespace detail
/// @brief Concept for checking whether a bwd weight convolution is invoked like old CK.
///
/// - SIGNATURE is the operation signature.
/// - Conv is a convolution instance created by the CK Builder API.
// Public two-parameter form: forwards to the detail concept and lets its defaulted
// helper parameters (SPATIAL_DIM, Types, Ops) be derived from SIGNATURE.
template <typename Conv, auto SIGNATURE>
concept CkConvBwdWeightInstance = detail::CkConvBwdWeightInstance<Conv, SIGNATURE>;
/// @brief Concept for checking whether a bwd weight convolution is multiple-d and
/// invoked like old CK.
///
/// - SIGNATURE is the operation signature.
/// - Conv is a convolution instance created by the CK Builder API.
// Public two-parameter form: forwards to the detail concept and lets its defaulted
// helper parameters (SPATIAL_DIM, Types, Ops) be derived from SIGNATURE.
template <typename Conv, auto SIGNATURE>
concept CkConvBwdWeightMultipleDInstance =
detail::CkConvBwdWeightMultipleDInstance<Conv, SIGNATURE>;
/// @brief `run()` specialization for backward weight convolution and old CK.
///
/// @tparam SIGNATURE Backward weight convolution signature.
/// @returns RunResult about how the operation completed (or not).
///
/// @see run()
template <auto SIGNATURE>
[[nodiscard]] RunResult run(CkConvBwdWeightInstance<SIGNATURE> auto& conv,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs)
{
using Types = factory::internal::ConvTensorDataTypes<SIGNATURE>;
constexpr auto spatial_dim = SIGNATURE.spatial_dim;
// Element-wise copy of `src` into `dst`; std::copy converts each element to dst's
// value type. No size check is done: `dst` must be at least as long as `src`.
// NOTE(review): std::copy is declared in <algorithm>, which this header does not
// include directly.
const auto copy = [](const auto& src, auto& dst) {
std::copy(src.begin(), src.end(), dst.begin());
};
// Converts tensor lengths/strides into the fixed-size array passed to MakeArgument.
// Assumes `src` holds exactly spatial_dim + 3 entries - TODO confirm.
const auto to_ck_lengths = [&](const auto& src) {
std::array<ck::index_t, spatial_dim + 3> result;
copy(src, result);
return result;
};
// Converts a per-spatial-dimension extent (strides, dilations, pads).
// Assumes `extent` holds exactly spatial_dim entries - TODO confirm.
const auto to_ck_extent = [&](const auto& extent) {
std::array<ck::index_t, spatial_dim> result;
copy(extent, result);
return result;
};
const auto param = args.to_ck_conv_param();
const auto input_desc = args.make_input_descriptor();
const auto weight_desc = args.make_weight_descriptor();
const auto output_desc = args.make_output_descriptor();
// Argument order is positional: input / weight / output pointers, then lengths and
// strides for input, weight and output, then filter strides, dilations, left and
// right pads, the three element-wise ops and finally the split-K value. The weight
// pointer is the only non-const one and comes from `outputs`.
auto ck_args = conv.MakeArgument(static_cast<const Types::InDataType*>(inputs.input),
static_cast<Types::WeiDataType*>(outputs.weight),
static_cast<const Types::OutDataType*>(inputs.output),
to_ck_lengths(input_desc.get_lengths()),
to_ck_lengths(input_desc.get_strides()),
to_ck_lengths(weight_desc.get_lengths()),
to_ck_lengths(weight_desc.get_strides()),
to_ck_lengths(output_desc.get_lengths()),
to_ck_lengths(output_desc.get_strides()),
to_ck_extent(param.conv_filter_strides_),
to_ck_extent(param.conv_filter_dilations_),
to_ck_extent(param.input_left_pads_),
to_ck_extent(param.input_right_pads_),
args.a_elementwise_op,
args.b_elementwise_op,
args.cde_elementwise_op,
args.k_batch);
// Report rejection by the instance as "not supported" instead of launching.
if(!conv.IsSupportedArgument(ck_args))
return RunResult::not_supported("invalid ck arguments");
// The second Run() argument is default-constructed (presumably the stream
// configuration - TODO confirm); Run()'s return value is wrapped in the RunResult.
return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, {}));
}
/// @brief `run()` specialization for backward weight convolution and old CK.
///
/// This overload is specialized for Multiple-D.
///
/// @tparam SIGNATURE Backward weight convolution signature.
/// @returns RunResult about how the operation completed (or not).
///
/// @see run()
template <auto SIGNATURE>
[[nodiscard]] RunResult run(CkConvBwdWeightMultipleDInstance<SIGNATURE> auto& conv,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs)
{
using Types = factory::internal::ConvTensorDataTypes<SIGNATURE>;
constexpr auto spatial_dim = SIGNATURE.spatial_dim;
// Element-wise copy of `src` into `dst`; std::copy converts each element to dst's
// value type. No size check is done: `dst` must be at least as long as `src`.
// NOTE(review): std::copy is declared in <algorithm>, which this header does not
// include directly.
const auto copy = [](const auto& src, auto& dst) {
std::copy(src.begin(), src.end(), dst.begin());
};
// Converts tensor lengths/strides into the fixed-size array passed to MakeArgument.
// Assumes `src` holds exactly spatial_dim + 3 entries - TODO confirm.
const auto to_ck_lengths = [&](const auto& src) {
std::array<ck::index_t, spatial_dim + 3> result;
copy(src, result);
return result;
};
// Converts a per-spatial-dimension extent (strides, dilations, pads).
// Assumes `extent` holds exactly spatial_dim entries - TODO confirm.
const auto to_ck_extent = [&](const auto& extent) {
std::array<ck::index_t, spatial_dim> result;
copy(extent, result);
return result;
};
const auto param = args.to_ck_conv_param();
const auto input_desc = args.make_input_descriptor();
const auto weight_desc = args.make_weight_descriptor();
const auto output_desc = args.make_output_descriptor();
// Same positional layout as the non-multiple-D overload, with three additional
// empty `{}` arguments: one after the output pointer and two after the output
// strides. They are placeholders; D tensors are not supplied yet (see the TODOs).
auto ck_args = conv.MakeArgument(static_cast<const Types::InDataType*>(inputs.input),
static_cast<Types::WeiDataType*>(outputs.weight),
static_cast<const Types::OutDataType*>(inputs.output),
{}, // TODO
to_ck_lengths(input_desc.get_lengths()),
to_ck_lengths(input_desc.get_strides()),
to_ck_lengths(weight_desc.get_lengths()),
to_ck_lengths(weight_desc.get_strides()),
to_ck_lengths(output_desc.get_lengths()),
to_ck_lengths(output_desc.get_strides()),
{}, // TODO
{}, // TODO
to_ck_extent(param.conv_filter_strides_),
to_ck_extent(param.conv_filter_dilations_),
to_ck_extent(param.input_left_pads_),
to_ck_extent(param.input_right_pads_),
args.a_elementwise_op,
args.b_elementwise_op,
args.cde_elementwise_op,
args.k_batch);
// Report rejection by the instance as "not supported" instead of launching.
if(!conv.IsSupportedArgument(ck_args))
return RunResult::not_supported("invalid ck arguments");
// The second Run() argument is default-constructed (presumably the stream
// configuration - TODO confirm); Run()'s return value is wrapped in the RunResult.
return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, {}));
}
} // namespace ck_tile::builder::test

View File

@@ -3,9 +3,8 @@
#pragma once
#include "ck_tile/builder/testing/conv_fwd.hpp"
#include "ck_tile/builder/testing/testing.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include <type_traits>
@@ -28,9 +27,39 @@ namespace detail {
/// namespace.
template <typename Conv, auto SIGNATURE>
concept CkTileConvInstance = requires(Conv&) {
// The signature itself has to describe a valid convolution.
requires ValidConvSignature<SIGNATURE>;
// The instance type must provide a BlockSize() that can be called as
// Conv::BlockSize(); this is the only property of Conv that is checked here.
{ Conv::BlockSize() };
};
/// @brief Shared launch path for CK Tile grouped convolution kernels.
///
/// Builds the host arguments from @p args and the three tensor pointers, asks the
/// kernel type whether it supports the resulting kernel arguments, and launches
/// the kernel through `ck_tile::launch_kernel`.
///
/// @tparam SIGNATURE The signature of the operation to perform.
/// @tparam InDataType Pointee type of the input tensor pointer.
/// @tparam WeiDataType Pointee type of the weight tensor pointer.
/// @tparam OutDataType Pointee type of the output tensor pointer.
/// @param conv The CK Tile kernel instance to launch.
/// @param args The run-time arguments of the operation.
/// @param input Pointer to the input tensor data.
/// @param weight Pointer to the weight tensor data.
/// @param output Pointer to the output tensor data.
/// @param s_conf Stream config used to launch the kernel.
/// @returns A `not_supported` RunResult if `Conv::IsSupportedArgument` rejects the
/// kernel arguments, otherwise a RunResult holding the value returned by
/// `ck_tile::launch_kernel`.
template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename OutDataType>
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
                            const Args<SIGNATURE>& args,
                            InDataType* input,
                            WeiDataType* weight,
                            OutDataType* output,
                            const ck_tile::stream_config s_conf)
{
    using Kernel = std::remove_reference_t<decltype(conv)>;
    using HostArgs =
        ck_tile::GroupedConvHostArgs<InDataType*, WeiDataType*, OutDataType*, ck_tile::PassThrough>;

    const auto conv_param = args.to_ck_tile_conv_param();
    HostArgs host_args(conv_param, input, weight, {}, output, args.k_batch);

    auto kernel_args     = Kernel::MakeKernelArgs(host_args);
    const dim3 grid_dim  = Kernel::GridSize(kernel_args);
    const dim3 block_dim = Kernel::BlockSize();

    if(!Kernel::IsSupportedArgument(kernel_args))
    {
        return RunResult::not_supported("unsupported ck_tile arguments");
    }

    // Intrawave-scheduled pipelines are launched with a minimum occupancy of 1,
    // every other scheduler with 2.
    constexpr bool is_intrawave =
        Kernel::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave;
    constexpr index_t min_occupancy = is_intrawave ? 1 : 2;

    const auto elapsed = ck_tile::launch_kernel(
        s_conf, ck_tile::make_kernel<min_occupancy>(conv, grid_dim, block_dim, 0, kernel_args));
    return RunResult::from_runtime(elapsed);
}
} // namespace detail
/// @brief Concept for checking whether a convolution is invoked like CK Tile.
@@ -48,44 +77,45 @@ concept CkTileConvInstance = detail::CkTileConvInstance<Conv, SIGNATURE>;
/// @brief `run()` specialization for forward convolution and CK Tile.
///
/// @tparam SIGNATURE Forward convolution signature.
/// @throws std::runtime_error if the arguments weren't actually valid for the
/// operation. This should be caught and reported by the testing framework.
/// @return std::tuple<bool, float> - whether the problem is supported and
/// kernel execution time (0.0f if s_conf time_kernel is false).
/// @returns RunResult about how the operation completed (or not).
///
/// @see run()
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
std::tuple<bool, float> run(CkTileConvInstance<SIGNATURE> auto& conv,
requires ConvDirectionIsForward<SIGNATURE>
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs,
const ck_tile::stream_config s_conf = {})
{
using Conv = std::remove_reference_t<decltype(conv)>;
const auto param = args.to_ck_tile_conv_param();
return detail::run(conv,
args,
static_cast<const void*>(inputs.input),
static_cast<const void*>(inputs.weight),
static_cast<void*>(outputs.output),
s_conf);
}
ck_tile::GroupedConvFwdHostArgs<> host_args(
param, inputs.input, inputs.weight, {}, outputs.output, args.k_batch);
auto kargs = Conv::MakeKernelArgs(host_args);
const dim3 grids = Conv::GridSize(kargs);
const dim3 blocks = Conv::BlockSize();
if(!Conv::IsSupportedArgument(kargs))
{
std::cout << "Not supported!";
return std::make_tuple(false, 0.f);
}
constexpr index_t minimum_occupancy =
Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2;
return std::make_tuple(
true,
ck_tile::launch_kernel(
s_conf, ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs)));
/// @brief `run()` overload that executes a backward weight convolution with a
/// CK Tile kernel.
///
/// Thin adapter over `detail::run()`: the input and output tensors are taken from
/// @p inputs as read-only pointers, the weight tensor is taken from @p outputs as
/// a writable pointer.
///
/// @tparam SIGNATURE Backward weight convolution signature.
/// @returns RunResult about how the operation completed (or not).
///
/// @see run()
template <auto SIGNATURE>
    requires ConvDirectionIsBackwardWeight<SIGNATURE>
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
                            const Args<SIGNATURE>& args,
                            const Inputs<SIGNATURE>& inputs,
                            const Outputs<SIGNATURE>& outputs,
                            const ck_tile::stream_config s_conf = {})
{
    const void* const input_ptr  = inputs.input;
    void* const weight_ptr       = outputs.weight;
    const void* const output_ptr = inputs.output;

    return detail::run(conv, args, input_ptr, weight_ptr, output_ptr, s_conf);
}
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,69 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/testing/tensor_initialization.hpp"
#include "ck_tile/builder/testing/testing_reflect.hpp"
#include "ck_tile/builder/testing/conv/args.hpp"
/// This file deals with the forward-specific details of running grouped
/// convolution forward operations. It mainly defines the data structures
/// (`Inputs` and `Outputs`), initialization, and validation. Note that
/// for this operation specifically, many of the operations are implemented
/// automatically via testing_reflect.hpp.
namespace ck_tile::builder::test {
/// @brief `Inputs` specialization for forward convolution.
///
/// Bundles the untyped pointers to the input and weight tensor data of a forward
/// convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see Inputs
template <auto SIGNATURE>
    requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
struct Inputs<SIGNATURE>
{
    void* input;
    void* weight;

    // See testing_reflect.hpp
    // Reports every tensor field to `inspect` as (name, descriptor, pointer-to-member).
    static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
    {
        inspect("input", args.make_input_descriptor(), &Inputs::input);
        inspect("weight", args.make_weight_descriptor(), &Inputs::weight);
    }
};
/// @brief `Outputs` specialization for forward convolution.
///
/// Holds the untyped pointer to the output tensor data of a forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see Outputs
template <auto SIGNATURE>
    requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
struct Outputs<SIGNATURE>
{
    void* output;

    // See testing_reflect.hpp
    // Reports the single tensor field to `inspect` as (name, descriptor, pointer-to-member).
    static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
    {
        inspect("output", args.make_output_descriptor(), &Outputs::output);
    }
};
/// @brief `init_inputs()` specialization for forward convolution.
///
/// Fills both the input and the weight tensor through
/// `init_tensor_buffer_uniform_fp` with the bounds -2.0 and 2.0.
///
/// @tparam SIGNATURE Forward convolution signature.
/// @param args The run-time arguments; they provide the tensor descriptors.
/// @param inputs The input tensors whose buffers are initialized.
///
/// @see init_inputs()
template <auto SIGNATURE>
    requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
{
    // Both tensors share the same value range.
    constexpr float lower_bound = -2.0f;
    constexpr float upper_bound = 2.0f;

    init_tensor_buffer_uniform_fp(
        inputs.input, args.make_input_descriptor(), lower_bound, upper_bound);
    init_tensor_buffer_uniform_fp(
        inputs.weight, args.make_weight_descriptor(), lower_bound, upper_bound);
}
} // namespace ck_tile::builder::test

View File

@@ -3,14 +3,14 @@
#pragma once
#include "ck_tile/builder/testing/conv_fwd.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/builder/testing/testing.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <type_traits>
#include <array>
/// This file contains the implementation details for invoking/testing
/// grouped convolution operations in old CK. The main item is the
/// fwd grouped convolution operations in old CK. The main item is the
/// `run()` function, which is the main implementation used to invoke
/// CK grouped forward convolution kernels.
@@ -18,10 +18,9 @@ namespace ck_tile::builder::test {
namespace detail {
/// @brief Concept for checking whether this is the reference convolution
/// implementation.
/// @brief Concept for checking whether a fwd convolution is invoked like old CK.
///
/// This is the same as `::ck_tile::builder::test::CkConvInstance`, except
/// This is the same as `::ck_tile::builder::test::CkConvFwdInstance`, except
/// with some utility aliases. For that reason, its moved to this detail
/// namespace.
template <typename Conv,
@@ -29,18 +28,21 @@ template <typename Conv,
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
// TODO: We shouldn't need to call into an internal namespace here.
typename Ops = factory::internal::ConvElementwiseOps<SIGNATURE>>
concept CkConvInstance = requires(Conv& conv,
// TODO: This should be changed depending on IsMultiA etc.
// Currently that is not yet supported elsewhere anyway.
const void* p_a,
const void* p_b,
void* p_e,
std::array<index_t, SPATIAL_DIM + 3> lengths,
std::array<index_t, SPATIAL_DIM + 3> strides,
std::array<index_t, SPATIAL_DIM> filter,
Ops::InElementwiseOp elementwise_a,
Ops::WeiElementwiseOp elementwise_b,
Ops::OutElementwiseOp elementwise_cde) {
concept CkConvFwdInstance = requires(Conv& conv,
// TODO: This should be changed depending on IsMultiA etc.
// Currently that is not yet supported elsewhere anyway.
const void* p_a,
const void* p_b,
void* p_e,
std::array<index_t, SPATIAL_DIM + 3> lengths,
std::array<index_t, SPATIAL_DIM + 3> strides,
std::array<index_t, SPATIAL_DIM> filter,
Ops::InElementwiseOp elementwise_a,
Ops::WeiElementwiseOp elementwise_b,
Ops::OutElementwiseOp elementwise_cde) {
requires ValidConvSignature<SIGNATURE>;
requires ConvDirectionIsForward<SIGNATURE>;
{
conv.MakeArgument(p_a,
p_b,
@@ -73,7 +75,7 @@ concept CkConvInstance = requires(Conv& conv,
} // namespace detail
/// @brief Concept for checking whether a convolution is invoked like old CK.
/// @brief Concept for checking whether a fwd convolution is invoked like old CK.
///
/// This concept is used to tell whether a convolution implementation is
/// likely to be an "old CK" implementation - that is, whether we should
@@ -83,20 +85,17 @@ concept CkConvInstance = requires(Conv& conv,
/// - SIGNATURE is the operation signature.
/// - Conv is a convolution instance created by the CK Builder API.
template <typename Conv, auto SIGNATURE>
concept CkConvInstance = detail::CkConvInstance<Conv, SIGNATURE>;
concept CkConvFwdInstance = detail::CkConvFwdInstance<Conv, SIGNATURE>;
/// @brief `run()` specialization for forward convolution and old CK.
///
/// @tparam SIGNATURE Forward convolution signature.
/// @throws std::runtime_error if the arguments weren't actually valid for the
/// operation. This should be caught and reported by the testing framework.
/// @return std::tuple<bool, float> - whether the problem is supported and
/// kernel execution time (0.0f if s_conf time_kernel is false).
/// @returns RunResult about how the operation completed (or not).
///
/// @see run()
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
std::tuple<bool, float> run(CkConvInstance<SIGNATURE> auto& conv,
[[nodiscard]] RunResult run(CkConvFwdInstance<SIGNATURE> auto& conv,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs,
@@ -126,6 +125,9 @@ std::tuple<bool, float> run(CkConvInstance<SIGNATURE> auto& conv,
const auto weight_desc = args.make_weight_descriptor();
const auto output_desc = args.make_output_descriptor();
if(args.k_batch != 1)
return RunResult::not_supported("ck fwd does not support k_batch != 1");
auto ck_args = conv.MakeArgument(inputs.input,
inputs.weight,
{},
@@ -147,11 +149,9 @@ std::tuple<bool, float> run(CkConvInstance<SIGNATURE> auto& conv,
args.cde_elementwise_op);
if(!conv.IsSupportedArgument(ck_args))
{
std::cout << "invalid argument" << std::endl;
}
return RunResult::not_supported("unsupported ck arguments");
return std::make_tuple(true, conv.MakeInvoker().Run(ck_args, s_conf));
return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, s_conf));
}
} // namespace ck_tile::builder::test

View File

@@ -0,0 +1,137 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/testing/testing.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include <stdexcept>
#include <vector>
/// This file contains the implementation details for invoking/testing
/// grouped convolution operations using the reference implementation.
/// The main item is the `run()` function, which is the primary way to
/// invoke the reference execution mechanism.
/// The implementation of this file mostly looks like `conv/fwd_ck.hpp`,
/// but it is made specific to the reference implementation, which is
/// invoked in a slightly different way.
namespace ck_tile::builder::test {
namespace detail {
/// @brief Concept for checking whether this is the reference convolution
/// implementation.
///
/// This concept is used to tell whether a convolution implementation is
/// likely to be the reference implementation - that is, whether we should
/// invoke it like the reference kernel. This is mainly used with `run()` to
/// differentiate which implementation that should be invoked.
///
/// - SIGNATURE is the operation signature.
/// - Conv is a convolution instance created by the CK Builder API.
/// - InDataType, WeiDataType, OutDataType are the types of the respective tensors.
///   These are pointee types: the concept forms `InDataType*`, `WeiDataType*` and
///   `OutDataType*` itself, so constness is expressed as e.g. `const void`.
template <typename Conv,
auto SIGNATURE,
typename InDataType,
typename WeiDataType,
typename OutDataType>
concept RefConvInstance = requires(Conv& conv,
InDataType* input,
WeiDataType* weight,
OutDataType* output,
ck::utils::conv::ConvParam param) {
// The signature itself has to describe a valid convolution.
requires ValidConvSignature<SIGNATURE>;
// The instance must be invocable as Run(input, weight, output, param);
// the result type of the call is not constrained.
{ conv.Run(input, weight, output, param) };
};
/// @brief Generic `run` implementation for forward/backwards reference kernels.
///
/// Rejects problems whose input, weight or output tensor is not packed, and
/// otherwise calls `conv.Run()` with the three tensor pointers and the CK
/// convolution parameters derived from @p args.
///
/// @tparam SIGNATURE The signature of the operation to perform.
/// @tparam InDataType Pointee type of the input tensor pointer.
/// @tparam WeiDataType Pointee type of the weight tensor pointer.
/// @tparam OutDataType Pointee type of the output tensor pointer.
///
/// @returns A `not_supported` RunResult for non-packed tensors, otherwise a
/// RunResult with a runtime of 0 (the reference does not report a runtime).
/// @see run()
template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename OutDataType>
[[nodiscard]] RunResult
run(RefConvInstance<SIGNATURE, InDataType, WeiDataType, OutDataType> auto& conv,
    const Args<SIGNATURE>& args,
    InDataType* input,
    WeiDataType* weight,
    OutDataType* output)
{
    // The output dimensions are not computed by hand; the existing
    // infrastructure derives them as part of the convolution parameters.
    const auto conv_param = args.to_ck_conv_param();

    // TODO: The reference convolution is currently missing a few features.
    // Report these cases as unsupported for now, but regard them as TODO items
    // that should be resolved eventually.
    const bool input_is_packed = args.make_input_descriptor().is_packed();
    if(!input_is_packed)
    {
        return RunResult::not_supported("TODO: Support non-packed input tensor in reference conv");
    }

    const bool weight_is_packed = args.make_weight_descriptor().is_packed();
    if(!weight_is_packed)
    {
        return RunResult::not_supported("TODO: Support non-packed weight tensor in reference conv");
    }

    const bool output_is_packed = args.make_output_descriptor().is_packed();
    if(!output_is_packed)
    {
        return RunResult::not_supported("TODO: Support non-packed output tensor in reference conv");
    }

    conv.Run(input, weight, output, conv_param);

    // ref conv does not return a meaningful runtime.
    return RunResult::from_runtime(0.0f);
}
} // namespace detail
/// @brief Concept for checking whether this is the reference convolution
/// forward implementation.
///
/// `detail::RefConvInstance` takes the pointee types of the tensors and forms
/// the pointers itself. The forward instance is therefore probed with a
/// `const void*` input, a `const void*` weight and a `void*` output. Passing
/// pointer types here instead would probe `Run()` with pointer-to-pointer
/// arguments, which only match through the implicit conversion to `void*` and
/// no longer check const-correctness.
///
/// - SIGNATURE is the operation signature; it must be a forward convolution.
/// - Conv is a convolution instance created by the CK Builder API.
template <typename Conv, auto SIGNATURE>
concept RefConvFwdInstance =
    detail::RefConvInstance<Conv, SIGNATURE, const void, const void, void> &&
    ConvDirectionIsForward<SIGNATURE>;
/// @brief `run()` overload that executes a forward convolution with the
/// reference forward implementation.
///
/// Passes the input and weight pointers from @p inputs and the output pointer
/// from @p outputs on to `detail::run()`.
///
/// @tparam SIGNATURE The signature of the operation to perform. Must be forwards.
/// @returns RunResult about how the operation completed (or not).
///
/// @see run()
template <auto SIGNATURE>
    requires ValidConvSignature<SIGNATURE> &&
             // NOTE(review): RefConvFwdInstance already checks both the signature
             // validity and the forward direction, so this clause looks redundant.
             // It is kept as-is so that overload selection is unchanged.
             ConvDirectionIsForward<SIGNATURE>
[[nodiscard]] RunResult run(RefConvFwdInstance<SIGNATURE> auto& conv,
                            const Args<SIGNATURE>& args,
                            const Inputs<SIGNATURE>& inputs,
                            const Outputs<SIGNATURE>& outputs)
{
    auto* const input_ptr  = inputs.input;
    auto* const weight_ptr = inputs.weight;
    auto* const output_ptr = outputs.output;

    return detail::run(conv, args, input_ptr, weight_ptr, output_ptr);
}
/// @brief Concept for checking whether this is the reference convolution
/// backward weight implementation.
///
/// `detail::RefConvInstance` takes the pointee types of the tensors and forms
/// the pointers itself. The backward weight instance is therefore probed with a
/// `const void*` input, a `void*` weight and a `const void*` output. Passing
/// pointer types here instead would probe `Run()` with pointer-to-pointer
/// arguments, which only match through the implicit conversion to `void*` and
/// no longer check const-correctness.
///
/// - SIGNATURE is the operation signature; it must be a backward weight convolution.
/// - Conv is a convolution instance created by the CK Builder API.
template <typename Conv, auto SIGNATURE>
concept RefConvBwdWeightInstance =
    detail::RefConvInstance<Conv, SIGNATURE, const void, void, const void> &&
    ConvDirectionIsBackwardWeight<SIGNATURE>;
/// @brief `run()` overload that executes a backward weight convolution with the
/// reference backward weight implementation.
///
/// Passes the input and output pointers from @p inputs and the weight pointer
/// from @p outputs on to `detail::run()`.
///
/// @tparam SIGNATURE The signature of the operation to perform. Must be backwards weight.
/// @returns RunResult about how the operation completed (or not).
///
/// @see run()
template <auto SIGNATURE>
[[nodiscard]] RunResult run(RefConvBwdWeightInstance<SIGNATURE> auto& conv,
                            const Args<SIGNATURE>& args,
                            const Inputs<SIGNATURE>& inputs,
                            const Outputs<SIGNATURE>& outputs)
{
    auto* const input_ptr  = inputs.input;
    auto* const weight_ptr = outputs.weight;
    auto* const output_ptr = inputs.output;

    return detail::run(conv, args, input_ptr, weight_ptr, output_ptr);
}
} // namespace ck_tile::builder::test

View File

@@ -1,88 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/testing/conv_fwd.hpp"
#include <stdexcept>
#include <vector>
/// This file contains the implementation details for invoking/testing
/// grouped convolution operations using the reference implementation.
/// The main item is the `run()` function, which is the primary way to
/// invoke the reference execution mechanism.
/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`,
/// but its made specific to the reference implementation, which is
/// invoked in a slightly different way.
namespace ck_tile::builder::test {
/// @brief Concept for checking whether this is the reference convolution
/// implementation.
///
/// This concept is used to tell whether a convolution implementation is
/// likely to be the reference implementation - that is, whether we should
/// invoke it like the reference kernel. This is mainly used with `run()` to
/// differentiate which implementation that should be invoked.
///
/// - SIGNATURE is the operation signature.
/// - Conv is a convolution instance created by the CK Builder API.
template <typename Conv, auto SIGNATURE>
concept RefConvInstance = requires(Conv& conv,
const void* input,
const void* weight,
void* output,
ck::utils::conv::ConvParam param) {
// The instance must be invocable as Run(input, weight, output, param) with
// read-only input/weight pointers and a writable output pointer. SIGNATURE is
// not inspected by this check.
{ conv.Run(input, weight, output, param) };
};
/// @brief `run()` overload that executes a forward convolution with the
/// reference implementation.
///
/// Problems with a non-packed input, weight or output tensor are not handled by
/// the reference yet: a message is printed to `std::cout` and the problem is
/// reported as unsupported.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @return std::tuple<bool, float> - whether the problem is supported and
/// kernel execution time (0.0f for reference).
/// @see run()
template <auto SIGNATURE>
    requires ValidConvSignature<SIGNATURE> &&
             // TODO: Maybe we can unify this implementation for bwd/weight too?
             // for now, just concern ourselves with reference and see when the
             // rest of the bwd/weight plumbing is there.
             ConvDirectionIsForward<SIGNATURE>
std::tuple<bool, float> run(RefConvInstance<SIGNATURE> auto& conv,
                            const Args<SIGNATURE>& args,
                            const Inputs<SIGNATURE>& inputs,
                            const Outputs<SIGNATURE>& outputs)
{
    // The output dimensions are not computed by hand; the existing
    // infrastructure derives them as part of the convolution parameters.
    const auto conv_param = args.to_ck_conv_param();

    // Prints the reason and produces the "not supported" result.
    const auto reject = [](const char* reason) {
        std::cout << reason << std::endl;
        return std::make_tuple(false, 0.0f);
    };

    // TODO: The reference convolution is currently missing a few features.
    // Report these cases as unsupported for now, but regard them as TODO items
    // that should be resolved eventually.
    if(!args.make_input_descriptor().is_packed())
    {
        return reject("TODO: Support non-packed input tensor in reference conv");
    }
    if(!args.make_weight_descriptor().is_packed())
    {
        return reject("TODO: Support non-packed weight tensor in reference conv");
    }
    if(!args.make_output_descriptor().is_packed())
    {
        return reject("TODO: Support non-packed output tensor in reference conv");
    }

    conv.Run(inputs.input, inputs.weight, outputs.output, conv_param);
    return std::make_tuple(true, 0.0f);
}
} // namespace ck_tile::builder::test

View File

@@ -12,6 +12,7 @@
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/testing/type_traits.hpp"
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck/utility/data_type.hpp"

View File

@@ -3,7 +3,11 @@
#pragma once
#include <concepts>
#include <iosfwd>
#include <optional>
#include <ostream>
#include <string>
#include <string_view>
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "ck_tile/builder/testing/tensor_buffer.hpp"
@@ -288,6 +292,57 @@ ValidationReport validate(const Args<SIGNATURE>& args,
Outputs<SIGNATURE> actual,
Outputs<SIGNATURE> expected) = delete;
/// @brief Outcome of a single `run()` invocation.
///
/// A result either carries an error message (the run did not complete) or a
/// runtime (the run completed). See the individual fields for details.
struct RunResult
{
    /// Set when there was a problem while running the algorithm. In that case
    /// the outputs are not valid (though they may be partially or completely
    /// overwritten), and the string holds a short debug message that indicates
    /// the problem. `std::nullopt` means the run completed.
    std::optional<std::string> error = std::nullopt;

    /// The runtime of the kernel in milliseconds, if measured. Whether the
    /// runtime is measured at all depends on the stream configuration passed to
    /// run(). 0 if not measured or if there was an error. The value is averaged
    /// over the total amount of runs actually done, which is usually also
    /// configured via the stream config.
    float runtime = 0.f;

    /// @brief Builds the RunResult for an operation that could not be run.
    ///
    /// @param msg A short debug message that will be included in the result.
    static constexpr RunResult not_supported(std::string_view msg)
    {
        return RunResult{.error = std::string(msg)};
    }

    /// @brief Builds the RunResult for a successful operation from its average
    /// runtime.
    ///
    /// @param runtime The runtime of the kernel in milliseconds.
    static constexpr RunResult from_runtime(const float runtime)
    {
        return RunResult{.error = std::nullopt, .runtime = runtime};
    }

    /// @brief Returns whether this algorithm executed successfully.
    ///
    /// In this case there should be no message in `error`.
    bool is_supported() const { return !error.has_value(); }
};
inline std::ostream& operator<<(std::ostream& os, const RunResult& result)
{
if(result.error.has_value())
return os << "invalid run (" << result.error.value() << ")";
else
return os << "successful run (" << result.runtime << " ms)";
}
/// @brief Invoke a device operation created by CK Builder.
///
/// This is the main function used to invoke a particular device operation
@@ -318,13 +373,14 @@ ValidationReport validate(const Args<SIGNATURE>& args,
/// @param outputs The output tensor data. The contents will be overwritten by
/// this function.
/// @param s_conf Stream config used to launch kernel.
/// @return std::tuple<bool, float> - whether the problem is supported and
/// kernel execution time (0.0f if s_conf time_kernel is false).
/// @returns RunResult about how the operation completed (or not).
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @see RunResult
template <auto SIGNATURE, typename Operation, typename StreamConf>
std::tuple<bool, float> run(Operation& operation,
[[nodiscard]] RunResult run(Operation& operation,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs,

View File

@@ -5,6 +5,8 @@
#include <string_view>
#include "ck_tile/builder/testing/testing.hpp"
/// testing.hpp requires developers of a type of SIGNATURE to implement
/// quite a lot of functionality for each SIGNATURE. For example, next
/// to `Args`, `Inputs`, `Outputs`, `run`, they also have to define

View File

@@ -51,6 +51,9 @@ struct ValidationReport
/// The number of elements which were bitwise 0.
uint64_t zero_elements;
/// The largest absolute difference found between an output element and its reference element.
double max_error;
/// @brief Check whether both the output and reference tensor were both all zeros.
///
/// If both tensors are all zero, it indicates either an incorrect testing setup
@@ -133,11 +136,12 @@ bool ValidationReport::check(std::string_view tensor_name,
// Initial pass: count errors
// Allocate and reset counter
auto d_counters = alloc_buffer(sizeof(uint64_t) * 2);
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2));
auto d_counters = alloc_buffer(sizeof(uint64_t) * 3);
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 3));
auto d_error_count = &reinterpret_cast<uint64_t*>(d_counters.get())[0];
auto d_zero_count = &reinterpret_cast<uint64_t*>(d_counters.get())[1];
auto d_max_error = &reinterpret_cast<double*>(d_counters.get())[2];
tensor_foreach(descriptor.get_lengths(), [=](auto index) {
using CKType = typename factory::internal::DataTypeToCK<DT>::type;
@@ -157,6 +161,7 @@ bool ValidationReport::check(std::string_view tensor_name,
const auto r = static_cast<double>(type_convert<float>(b));
const auto err = std::abs(o - r);
atomicMax(d_max_error, err);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
// We expect the number of errors to be very low, so just use an atomic
@@ -188,6 +193,8 @@ bool ValidationReport::check(std::string_view tensor_name,
check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
uint64_t zero_count = 0;
check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
double max_error = 0;
check_hip(hipMemcpy(&max_error, d_max_error, sizeof(double), hipMemcpyDeviceToHost));
// TODO: Gather detailed coordinates.
@@ -196,6 +203,7 @@ bool ValidationReport::check(std::string_view tensor_name,
.wrong_elements = error_count,
.total_elements = descriptor.get_element_size(),
.zero_elements = zero_count,
.max_error = max_error,
});
return reports_.back().is_ok();

View File

@@ -157,6 +157,7 @@ enum class PipelineVersion
V3,
V4,
V5,
V6,
WEIGHT_ONLY
};
@@ -328,6 +329,7 @@ inline std::string_view to_string(PipelineVersion ver)
case V3: return "V3";
case V4: return "V4";
case V5: return "V5";
case V6: return "V6";
case WEIGHT_ONLY: return "WEIGHT_ONLY";
default: return "Unknown";
}

View File

@@ -168,7 +168,7 @@ add_ck_builder_test(test_ckb_build_fwd_instances
conv/ck/test_ckb_conv_fwd_3d_fp16.cpp
conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp
)
)
target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
set(BWD_WEIGHT_TESTS

View File

@@ -1,23 +1,30 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/builder/testing/conv/bwd_weight.hpp"
#include "ck_tile/builder/testing/conv/bwd_weight_ck.hpp"
#include "ck_tile/builder/testing/conv/reference.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "testing_utils.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using enum ck_tile::builder::TensorLayout;
using ck_tile::test::MatchesReference;
using ck_tile::test::SuccessfulRun;
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::BF16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = NGCW}},
.input = {.config = {.layout = GNWC}},
.weight = {.config = {.layout = GKXC}},
.output = {.config = {.layout = NGKW}}};
.output = {.config = {.layout = GNWK}}};
constexpr auto ALGORITHM =
cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{}
@@ -30,14 +37,58 @@ constexpr auto ALGORITHM =
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
TEST(BwdWeight_1DBf16_CShuffle_V3, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3",
expected_transfer_parameters,
"Filter1x1Stride1Pad0",
"NGCW,GKXC,NGKW",
"GNWC,GKXC,GNWK",
"PassThrough,PassThrough,PassThrough",
"Intrawave",
"v2"});
}
// Runs the kernel under test and the reference implementation on the same
// problem and compares the resulting weight tensors.
TEST(BwdWeight_1DBf16_CShuffle_V3, Execution)
{
    // Note: XDL kernel
    const bool is_gfx9 = ck_tile::get_device_name().starts_with("gfx9");
    if(!is_gfx9)
    {
        GTEST_SKIP() << "unsupported architecture";
    }

    ckt::Args<SIGNATURE> args = {
        .lengths =
            {
                .batch_size      = 16,
                .groups          = 1,
                .input_channels  = 32,
                .output_channels = 48,
                .image           = {.width = 64},
                .filter          = {.width = 1},
            },
        .filter_strides     = {.width = 1},
        .filter_dilation    = {.width = 1},
        .input_left_pad     = {.width = 0},
        .input_right_pad    = {.width = 0},
        .a_elementwise_op   = {},
        .b_elementwise_op   = {},
        .cde_elementwise_op = {},
    };

    // One set of inputs shared by both runs, and a separate output set per run.
    auto input_tensors = ckt::alloc_inputs(args);
    auto actual        = ckt::alloc_outputs(args);
    auto expected      = ckt::alloc_outputs(args);

    ckt::init_inputs(args, input_tensors.get());

    auto kernel = Instance{};
    EXPECT_THAT(ckt::run(kernel, args, input_tensors.get(), actual.get()), SuccessfulRun());

    auto ref_kernel = Reference{};
    EXPECT_THAT(ckt::run(ref_kernel, args, input_tensors.get(), expected.get()), SuccessfulRun());

    EXPECT_THAT(actual.get(), MatchesReference(args, expected.get()));
}

View File

@@ -4,8 +4,9 @@
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/builder/testing/conv_fwd_ck.hpp"
#include "ck_tile/builder/testing/conv_fwd_reference.hpp"
#include "ck_tile/builder/testing/conv/fwd.hpp"
#include "ck_tile/builder/testing/conv/fwd_ck.hpp"
#include "ck_tile/builder/testing/conv/reference.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "testing_utils.hpp"
@@ -14,6 +15,7 @@ namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using ck_tile::test::MatchesReference;
using ck_tile::test::SuccessfulRun;
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
@@ -50,10 +52,11 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, Create)
"MNKPadding"});
}
TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd)
TEST(Fwd2DFp16_CShufV3_GNHWC, Execution)
{
if(!ck_tile::get_device_name().starts_with("gfx9"))
{
// Note: XDL kernel
GTEST_SKIP() << "unsupported architecture";
}
@@ -91,10 +94,10 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd)
ckt::init_inputs(args, inputs.get());
auto conv = Instance{};
ckt::run(conv, args, inputs.get(), outputs.get());
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
auto ref_conv = Reference{};
ckt::run(ref_conv, args, inputs.get(), reference.get());
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
}

View File

@@ -1,35 +1,47 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/builder/testing/conv/bwd_weight.hpp"
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
#include "ck_tile/builder/testing/conv/reference.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "utils/ckb_conv_tile_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "testing_utils.hpp"
namespace {
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using namespace ck_tile::builder::test_utils;
using enum ck_tile::builder::TensorLayout;
using ck_tile::test::MatchesReference;
using ck_tile::test::SuccessfulRun;
TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
// Operation under test: a 2D backward weight convolution on FP16 data with FP32
// accumulation, using NHWGC input, GKYXC weight and NHWGK output layouts.
constexpr auto SIGNATURE = cku::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::FP16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = NHWGC}},
.weight = {.config = {.layout = GKYXC}},
.output = {.config = {.layout = NHWGK}}};
// CK Tile grouped convolution kernel configuration used to build the instance
// under test; every setting comes from a named preset in the test utilities.
constexpr auto ALGORITHM =
cku::ConvAlgorithm_Tile_GroupedConvolutionKernel{}
.with_tile_specializations(ckb::TileConvSpecialization::DEFAULT)
.with_tile_thread_block(cku::TileThreadBlock_64x64x64)
.with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(cku::TileTransfer_4x4x4)
.with_tile_optimizations(ckt::TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
// Builder combines the signature with the algorithm; Instance is the kernel type it produces.
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
// Reference implementation for the same signature (presumably used to compute
// the expected results - confirm against the Execution test).
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
TEST(BwdWeight_2D_FP16_NHWGC, Create)
{
constexpr ConvSignature BwdWeightConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::BACKWARD_WEIGHT,
.data_type = DataType::FP16,
.accumulation_data_type = DataType::FP32,
.input = {.config = {.layout = TensorLayout::NHWGC}},
.weight = {.config = {.layout = TensorLayout::GKYXC}},
.output = {.config = {.layout = TensorLayout::NHWGK}}};
constexpr auto BwdWeightConvAlgorithm =
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
.with_tile_specializations(TileConvSpecialization::DEFAULT)
.with_tile_thread_block(TileThreadBlock_64x64x64)
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(TileTransfer_4x4x4)
.with_tile_optimizations(TileOptimizations{
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
using Builder = ConvBuilder<BwdWeightConvSignature, BwdWeightConvAlgorithm>;
run_ck_tile_test<Builder>({
cku::run_ck_tile_test<Builder>({
"grouped_convolution_backward_weight",
"fp16",
"NHWGC_GKYXC_NHWGK",
@@ -49,4 +61,38 @@ TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_
});
}
} // namespace
// Runs the instance under test (`Instance`) and the reference instance
// (`Reference`) on the same arguments and the same initialized inputs, checks
// that both runs report success, and then compares the two output sets.
TEST(BwdWeight_2D_FP16_NHWGC, Execution)
{
// Problem description shared by both runs: 3x3 filter, stride 1, dilation 1,
// zero left/right padding. The element-wise operations are value-initialized.
ckt::Args<SIGNATURE> args = {
.lengths =
{
.batch_size = 2,
.groups = 4,
.input_channels = 32,
.output_channels = 48,
.image = {.width = 32, .height = 56},
.filter = {.width = 3, .height = 3},
},
.filter_strides = {.width = 1, .height = 1},
.filter_dilation = {.width = 1, .height = 1},
.input_left_pad = {.width = 0, .height = 0},
.input_right_pad = {.width = 0, .height = 0},
.a_elementwise_op = {},
.b_elementwise_op = {},
.cde_elementwise_op = {},
};
// One set of inputs is shared; `outputs` receives the result of the instance
// under test and `reference` receives the result of the reference instance.
auto inputs = ckt::alloc_inputs(args);
auto outputs = ckt::alloc_outputs(args);
auto reference = ckt::alloc_outputs(args);
ckt::init_inputs(args, inputs.get());
// Run the instance under test; the returned run result must be successful.
auto conv = Instance{};
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
// Run the reference instance on the same inputs, writing into `reference`.
auto ref_conv = Reference{};
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
// Validate the tested outputs against the reference outputs.
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
}

View File

@@ -4,8 +4,8 @@
#include "utils/ckb_conv_tile_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp"
#include "ck_tile/builder/testing/conv_fwd_reference.hpp"
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
#include "ck_tile/builder/testing/conv/reference.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "testing_utils.hpp"
@@ -13,6 +13,9 @@ namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using ck_tile::test::MatchesReference;
using ck_tile::test::SuccessfulRun;
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::FORWARD,
@@ -75,10 +78,10 @@ TEST(Fwd2DFp16_CShufV3_NHWGC, EndToEnd)
ckt::init_inputs(args, inputs.get());
auto conv = Instance{};
ckt::run(conv, args, inputs.get(), outputs.get());
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
auto ref_conv = Reference{};
ckt::run(ref_conv, args, inputs.get(), reference.get());
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
EXPECT_THAT(outputs.get(), ck_tile::test::MatchesReference(args, reference.get()));
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
}

View File

@@ -5,11 +5,14 @@
#include "testing_utils.hpp"
namespace ckt = ck_tile::builder::test;
using ck_tile::test::HipError;
using ck_tile::test::HipSuccess;
using ck_tile::test::InstanceMatcher;
using ck_tile::test::InstanceSet;
using ck_tile::test::StringEqWithDiff;
using ck_tile::test::SuccessfulRun;
TEST(InstanceSet, FromFactory)
{
@@ -107,3 +110,17 @@ TEST(HipStatusMatcher, Basic)
EXPECT_THAT(hipSuccess, Not(HipError(hipErrorInvalidValue)));
EXPECT_THAT(hipErrorOutOfMemory, Not(HipError(hipErrorInvalidValue)));
}
// Checks both outcomes of the SuccessfulRun() matcher: a result built with
// RunResult::from_runtime(0) must match, and a result built with
// RunResult::not_supported(...) must not.
TEST(RunResultMatcher, Basic)
{
EXPECT_THAT(ckt::RunResult::from_runtime(0), SuccessfulRun());
EXPECT_THAT(ckt::RunResult::not_supported("test error"), Not(SuccessfulRun()));
}
// Checks the explanation text that SuccessfulRun() writes to the match
// listener when it is given a not-supported run result.
TEST(RunResultMatcher, ExplainMatchResult)
{
// String-backed listener so the explanation can be inspected afterwards.
testing::StringMatchResultListener listener;
// The match itself must fail for a not-supported result...
EXPECT_TRUE(!ExplainMatchResult(
SuccessfulRun(), ckt::RunResult::not_supported("test error"), &listener));
// ...and the explanation must be the "run failed: " prefix plus the error text.
EXPECT_THAT(listener.str(), StringEqWithDiff("run failed: test error"));
}

View File

@@ -339,4 +339,22 @@ void HipStatusMatcher::DescribeNegationTo(std::ostream* os) const
return ::testing::MakeMatcher(new HipStatusMatcher(error));
}
/// @brief Decide whether a run result counts as a successful run.
///
/// When the result carries an error, the text "run failed: " followed by that
/// error is written to the listener (only if a listener was supplied).
///
/// @param actual   The run result to inspect.
/// @param listener Receives the failure explanation; only written to when non-null.
/// @return The value of `actual.is_supported()`.
bool RunResultMatcher::MatchAndExplain(builder::test::RunResult actual,
                                       ::testing::MatchResultListener* listener) const
{
    if(actual.error.has_value())
    {
        // Only explain when there is somewhere to write the explanation to.
        if(listener != nullptr)
        {
            *listener << "run failed: " << actual.error.value();
        }
    }

    return actual.is_supported();
}
/// @brief Write the description of a matching value ("successful run") to `os`.
void RunResultMatcher::DescribeTo(std::ostream* os) const
{
    (*os) << "successful run";
}
/// @brief Write the description of a non-matching value ("unsuccessful run") to `os`.
void RunResultMatcher::DescribeNegationTo(std::ostream* os) const
{
    (*os) << "unsuccessful run";
}
/// @brief Build a Google Test matcher backed by a freshly allocated RunResultMatcher.
///
/// The raw pointer is handed to ::testing::MakeMatcher, exactly as the
/// other matcher factories in this file do.
::testing::Matcher<builder::test::RunResult> SuccessfulRun()
{
    auto* matcher_impl = new RunResultMatcher();
    return ::testing::MakeMatcher(matcher_impl);
}
} // namespace ck_tile::test

View File

@@ -161,6 +161,23 @@ struct HipStatusMatcher : public ::testing::MatcherInterface<hipError_t>
/// @param error The error to expect.
::testing::Matcher<hipError_t> HipError(hipError_t error);
/// @brief Google Test matcher for RunResult values.
///
/// `ckt::run` returns a RunResult which indicates whether there was any
/// problem while running the algorithm. This matcher is used to match those
/// values: a value matches when its `is_supported()` returns true.
struct RunResultMatcher : public ::testing::MatcherInterface<builder::test::RunResult>
{
/// @brief Check a run result and explain a failure.
///
/// If `actual` carries an error, "run failed: <error>" is written to
/// `listener` when one is supplied.
///
/// @param actual   The run result to check.
/// @param listener Destination for the failure explanation.
/// @return The value of `actual.is_supported()`.
bool MatchAndExplain(builder::test::RunResult actual,
::testing::MatchResultListener* listener) const override;
/// @brief Describe a matching value; writes "successful run" to `os`.
void DescribeTo(std::ostream* os) const override;
/// @brief Describe a non-matching value; writes "unsuccessful run" to `os`.
void DescribeNegationTo(std::ostream* os) const override;
};
/// @brief Construct a Google Test matcher that checks that a ckt::run result
/// was successful.
::testing::Matcher<builder::test::RunResult> SuccessfulRun();
template <auto SIGNATURE>
struct ReferenceOutputMatcher
: public ::testing::MatcherInterface<builder::test::Outputs<SIGNATURE>>
@@ -180,6 +197,21 @@ struct ReferenceOutputMatcher
if(listener->IsInterested() && !errors.empty())
{
*listener << errors.size() << " tensors failed to validate";
for(const auto& e : errors)
{
*listener << "\n - " << e.tensor_name << ": ";
if(e.is_all_zero())
*listener << "all elements in actual and expected tensors are zero";
else
{
// Round to 2 digits
const float percentage = e.wrong_elements * 10000 / e.total_elements / 100.f;
*listener << e.wrong_elements << "/" << e.total_elements
<< " incorrect elements (~" << percentage << "%)";
}
}
}
return errors.empty();

View File

@@ -3,7 +3,7 @@
#include "impl/conv_signature_types.hpp"
#include "testing_utils.hpp"
#include "ck_tile/builder/testing/conv_fwd.hpp"
#include "ck_tile/builder/testing/conv/fwd.hpp"
#include "ck_tile/builder/testing/tensor_foreach.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>

View File

@@ -296,5 +296,8 @@ TEST(MatchesReference, Incorrect)
testing::StringMatchResultListener listener;
EXPECT_TRUE(!ExplainMatchResult(MatchesReference(args, expected), actual, &listener));
EXPECT_THAT(listener.str(), StringEqWithDiff("1 tensors failed to validate"));
EXPECT_THAT(listener.str(),
StringEqWithDiff( //
"1 tensors failed to validate\n"
" - a: 625/625 incorrect elements (~100%)"));
}

View File

@@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>

View File

@@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>

View File

@@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stri
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>

View File

@@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>

View File

@@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>

View File

@@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stri
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>

View File

@@ -1,5 +1,6 @@
#include "../../builder/test/utils/ckb_conv_tile_test_configs.hpp"
#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp"
#include "ck_tile/builder/testing/conv/fwd.hpp"
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;

View File

@@ -2,8 +2,6 @@
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
auto conv = Instance{};
bool is_supported;
float avg_time;
std::tie(is_supported, avg_time) = ckt::run(conv, args, inputs, outputs, s_conf);
return std::make_tuple(is_supported, avg_time, conv.GetInstanceString());
auto conv = Instance{};
ckt::RunResult result = ckt::run(conv, args, inputs, outputs, s_conf);
return std::make_tuple(result.is_supported(), result.runtime, conv.GetInstanceString());

View File

@@ -0,0 +1,225 @@
# Build Time Optimization
Tracking issue: [#3575](https://github.com/ROCm/composable_kernel/issues/3575)
This document describes techniques for reducing C++ template instantiation overhead in the Composable Kernel codebase.
## Why Build Time Matters
Composable Kernel relies heavily on C++ template metaprogramming to achieve GPU kernels with no runtime abstraction penalty. However, deep template instantiation can significantly impact build times. A single translation unit may trigger hundreds of thousands of template instantiations, with each instantiation adding to compile time.
## Key Types
This codebase uses compile-time types to enable zero-overhead abstractions:
- `Number<N>` - compile-time integer, enables static dispatch and compile-time arithmetic
- `Sequence<Is...>` - compile-time integer sequence, used for dimension ordering and index manipulation
- `Tuple<Ts...>` - heterogeneous container holding different types, used for tensor descriptors and transforms
These types allow the compiler to fully unroll loops, eliminate branches, and inline all operations - producing GPU kernels with no runtime abstraction cost.
## Optimization Techniques
### 1. Replace Recursive Templates with Pack Expansion
Recursive template patterns create O(N) instantiation depth - the compiler must instantiate each level before proceeding to the next:
```
sequence_gen_impl<5, F>
→ sequence_gen_impl<4, F>
→ sequence_gen_impl<3, F>
→ ...
```
Using `__make_integer_seq` (Clang/MSVC) combined with pack expansion reduces this to constant depth - the compiler generates the entire sequence in one step internally, without recursive template instantiation.
**Before** (O(N) recursive instantiation):
```cpp
template <index_t N, typename F, index_t... Is>
struct sequence_gen_impl
{
using type = typename sequence_gen_impl<N-1, F, F{}(Number<N-1>{}), Is...>::type;
};
template <typename F, index_t... Is>
struct sequence_gen_impl<0, F, Is...>
{
using type = Sequence<Is...>;
};
```
**After** (constant depth using compiler intrinsic + pack expansion):
```cpp
namespace detail {
template <typename T, T... Is>
struct sequence_gen_helper
{
// Apply functor F to all indices via pack expansion
// F{}(Number<0>{}), F{}(Number<1>{}), ..., F{}(Number<N-1>{})
template <typename F>
using apply = Sequence<F{}(Number<Is>{})...>;
};
} // namespace detail
template <index_t N, typename F>
struct sequence_gen
{
// __make_integer_seq<sequence_gen_helper, index_t, N> produces
// sequence_gen_helper<index_t, 0, 1, ..., N-1> with constant depth
using type =
typename __make_integer_seq<detail::sequence_gen_helper, index_t, N>::template apply<F>;
};
```
Note: This document assumes C++17 or later. While `std::make_integer_sequence` (introduced in C++14) is the standard library facility for generating integer sequences, it only produces `std::integer_sequence<T, ...>`. We use `__make_integer_seq` directly because it accepts any template as its first argument, enabling this pattern where the helper class receives the index pack directly.
### 2. Replace Lambdas with Named Functors
Each lambda expression creates a unique closure type, causing separate template instantiations at every call site. Named functors share a single type across all uses.
**Before** (lambda creates unique instantiations at each call site):
```cpp
// The lambda inside transform_tensor_descriptor:
generate_tuple([](auto i) { return Sequence<i>{}; }, Number<N>{});
```
**After** (named functor shares instantiations):
```cpp
// Define functor once
struct generate_identity_sequence
{
template <index_t I>
__host__ __device__ constexpr auto operator()(Number<I>) const
{
return Sequence<I>{};
}
};
// Use everywhere - shares instantiations
generate_tuple(generate_identity_sequence{}, Number<N>{});
```
This reduced `transform_tensor_descriptor` instantiations from 388 to 32 (92% reduction).
**Example: container_concat**
```cpp
// Before: lambda creates unique type per call site
// (unpack2 applies a functor to all elements from both tuples)
template <typename... X, typename... Y>
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
return unpack2([](auto&&... zs) { return make_tuple(forward<decltype(zs)>(zs)...); }, tx, ty);
}
// After: named functor shares instantiations
struct make_tuple_functor
{
template <typename... Ts>
__host__ __device__ constexpr auto operator()(Ts&&... xs) const
{
return make_tuple(forward<Ts>(xs)...);
}
};
template <typename... X, typename... Y>
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
return unpack2(make_tuple_functor{}, tx, ty);
}
```
This reduced `container_concat` instantiations from 186 to 93 (50% reduction).
**Example: make_uniform_tuple**
For patterns that create tuples with repeated values:
```cpp
// Before: unique lambda type at each call site
generate_tuple([](auto) { return some_value; }, Number<N>{});
// After: dedicated helper function
template <index_t N, typename T>
__host__ __device__ constexpr auto make_uniform_tuple(T&& value)
{
return detail::make_uniform_tuple_impl(static_cast<T&&>(value), make_index_sequence<N>{});
}
// Usage
make_uniform_tuple<N>(some_value);
```
### 3. Use Constexpr Loops Instead of Template Recursion
Template recursion creates N template instantiations for N iterations. A constexpr loop also executes at compile time, but it requires only a single template instantiation: each iteration is an ordinary constant-evaluated step rather than a new instantiation. Both approaches still do O(N) work, but the constexpr loop is significantly faster because it avoids the per-iteration overhead of template instantiation.
**Before** (O(N) template instantiations):
```cpp
// Simplified example - actual CK code used more complex recursive patterns
template <index_t Target, typename Seq, index_t Pos, bool AtEnd>
struct find_source_index_impl
{
static constexpr index_t value =
(Seq::template At<Pos>() == Target) ? Pos : find_source_index_impl<Target, Seq, Pos+1, (Pos+1 == Seq::Size())>::value;
};
template <index_t Target, typename Seq, index_t Pos>
struct find_source_index_impl<Target, Seq, Pos, true>
{
static constexpr index_t value = -1; // not found
};
```
**After** (single instantiation with constexpr loop):
```cpp
template <index_t Target, index_t... Is>
__host__ __device__ constexpr index_t find_source_index(Sequence<Is...>)
{
// Simplified example - actual implementation handles empty sequences
constexpr index_t values[] = {Is...};
for(index_t i = 0; i < sizeof...(Is); ++i)
if(values[i] == Target) return i;
return -1; // not found
}
```
This reduced `sequence_map_inverse` instantiations from 45 to 10 (78% reduction) and wall-clock time by 95%.
### 4. Use Fold Expressions for Accumulation
Fold expressions (C++17) can replace recursive template patterns for accumulation operations.
**Before** (uses helper utilities that hide template recursion: `generate_tuple` recursively constructs a tuple of N elements, and `container_reduce` recursively reduces that tuple):
```cpp
const auto element_space_size = container_reduce(
generate_tuple([&](auto i) {
return (lengths[i] - Number<1>{}) * strides[i];
}, Number<N>{}),
math::plus{}, Number<1>{});
```
**After** (single fold expression):
```cpp
template <typename... Lengths, typename... Strides, index_t... Is>
__host__ __device__ constexpr auto compute_element_space_size(
const Tuple<Lengths...>& lengths,
const Tuple<Strides...>& strides,
Sequence<Is...>)
{
return (LongNumber<1>{} + ... +
((lengths[Number<Is>{}] - Number<1>{}) * strides[Number<Is>{}]));
}
```
This reduced `calculate_element_space_size` instantiations from 24 to 10 (58% reduction) and wall-clock time by 73%.

View File

@@ -55,9 +55,6 @@
#ifndef CK_ENABLE_FP32
#define CK_ENABLE_FP32 "ON"
#endif
#ifndef CK_ENABLE_TF32
#define CK_ENABLE_TF32 "ON"
#endif
#ifndef CK_ENABLE_FP64
#define CK_ENABLE_FP64 "ON"
#endif
@@ -88,10 +85,6 @@
#cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@
#endif
#ifndef CK_ENABLE_TF32
#cmakedefine CK_ENABLE_TF32 @CK_ENABLE_TF32@
#endif
#ifndef CK_ENABLE_FP64
#cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@
#endif

View File

@@ -160,6 +160,7 @@ struct ThreadGroupTransferGlobal
// check if src element is valid
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
oob_thread_scratch_.template SetAsType<bool>(vgpr_data_idx_seq, is_src_valid);
// Vector length of elementwise operation
constexpr auto get_elem_op_vec_len = []() {
@@ -195,14 +196,12 @@ struct ThreadGroupTransferGlobal
using dst_vector_type = vector_type_maker_t<DstData, VectorSize>;
using dst_vector_t = typename dst_vector_type::type;
using vector_t = typename vector_type_maker<DstData, VectorSize>::type::type;
dst_vector_type op_r_v;
// Load data from memory in src_vector first
src_vector_container src_vector =
src_vector_container{grid_buf.template Get<src_vector_container_t, DoTranspose>(
src_coord_.GetOffset(), true)};
auto index = is_src_valid || !DoTranspose ? src_coord_.GetOffset() : 0;
src_vector_container src_vector = src_vector_container{
grid_buf.template Get<src_vector_container_t, DoTranspose>(index, true)};
// apply the src elementwise op and convert to DstData under the hood if needed
static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) {
@@ -213,9 +212,8 @@ struct ThreadGroupTransferGlobal
// store result in dvgpr_ (static array holding loaded data).
// At this point data is already converted to DstData type and
// the elementwise operation has been applied
dvgpr_.template SetAsType<dst_vector_t>(
vgpr_data_idx_seq,
is_src_valid ? op_r_v.template AsType<dst_vector_t>()[I0] : vector_t(0));
src_dvgpr_.template SetAsType<dst_vector_t>(vgpr_data_idx_seq,
op_r_v.template AsType<dst_vector_t>()[I0]);
// For each dimension move fwd, bwd or don't move
static_for<0, nDim, 1>{}([&](auto i) {
@@ -248,6 +246,39 @@ struct ThreadGroupTransferGlobal
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
constexpr auto ordered_fwd_step = StepsPerIteration{};
// OOB check
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// calculate src data index and make sequence
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}(
[&](auto i) { ordered_idx(i) = ordered_src_access_idx[i]; });
return container_reorder_given_old2new(ordered_idx, src_dim_access_order);
}();
// make sequence to access vgpr data. Add zero as last element of src_data_idx_seq
constexpr auto vgpr_data_idx_seq = generate_sequence_v2(
[&](auto i) {
if constexpr(i.value < src_data_idx.Size())
{
return Number<src_data_idx[i]>{};
}
else
{
return Number<0>{};
}
},
Number<src_data_idx.Size() + 1>{});
auto op_r = src_dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq);
const bool is_src_valid =
oob_thread_scratch_.template GetAsType<bool>(vgpr_data_idx_seq);
auto op_r_v = is_src_valid ? op_r : dst_vector_t(0);
dst_dvgpr_.template SetAsType<dst_vector_t>(vgpr_data_idx_seq, op_r_v);
});
// make forward steps
// forward step for each iteration just add 1
const auto dst_forward_steps = generate_tuple(
@@ -352,7 +383,7 @@ struct ThreadGroupTransferGlobal
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
true,
dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq));
dst_dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq));
// For each dimension move fwd, bwd or don't move
static_for<0, nDim, 1>{}([&](auto i) {
@@ -389,6 +420,14 @@ struct ThreadGroupTransferGlobal
return make_naive_tensor_descriptor_packed(access_lengths_as_tuple);
}
// Builds the packed descriptor used for the per-access source scratch
// (one entry per iteration in NumberOfIterations, plus a trailing dimension of length 1).
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
    return make_naive_tensor_descriptor_packed(
        container_push_back(sequence_to_tuple_of_number(NumberOfIterations{}), Number<1>{}));
}
static constexpr auto thread_data_scratch_desc_ = decltype(GetThreadScratchDataDescriptor()){};
using ThreadScratchData = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
@@ -396,7 +435,17 @@ struct ThreadGroupTransferGlobal
decltype(thread_data_scratch_desc_),
true>;
ThreadScratchData dvgpr_;
static constexpr auto src_oob_thread_scratch_desc_ =
decltype(GetSrcThreadScratchDescriptor()){};
using OOBThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
bool,
1,
decltype(src_oob_thread_scratch_desc_),
true>;
ThreadScratchData src_dvgpr_;
ThreadScratchData dst_dvgpr_;
OOBThreadScratch oob_thread_scratch_;
SrcCoord src_coord_;
DstCoord dst_coord_;
const ElementwiseOperation element_op_;

View File

@@ -11,8 +11,6 @@ namespace ck {
namespace tensor_operation {
namespace device {
#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,

View File

@@ -11,8 +11,6 @@ namespace ck {
namespace tensor_operation {
namespace device {
#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,

View File

@@ -833,6 +833,26 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3
return false;
}
// On gfx12, reject the argument when the gridwise GEMM reports that wave transfer
// cannot be used for matrix A with the given M/K extents.
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message separated from the file path that follows.
        std::cout << "Wave Transfer not applicable for matrix A " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}
// Same check for matrix B, using the N/K extents.
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
    {
        // Trailing space keeps the message separated from the file path that follows.
        std::cout << "Wave Transfer not applicable for matrix B " << __FILE__ << ":"
                  << __LINE__ << ", in function: " << __func__ << std::endl;
    }
    return false;
}
// check vector access
static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),

View File

@@ -3,77 +3,21 @@
#pragma once
#include <iostream>
#include <sstream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/utility/tuple.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Kernel entry point for the batched GEMM+GEMM WMMA CShuffle V3 operation.
// Each workgroup derives its batch index from its 1D block id, offsets the A/B0/B1/C
// base pointers by the per-batch offsets and forwards everything to GridwiseOp::Run.
// The body is compiled only for host passes and gfx11/gfx12 device passes.
template <typename DeviceOp, typename GridwiseOp, bool HasMainKBlockLoop, TailNumber TailNum>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
    __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];

    // The grid is split evenly between batches; find the batch this block belongs to.
    const index_t blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
    const index_t batch_idx =
        __builtin_amdgcn_readfirstlane(get_block_1d_id() / blocks_per_batch);

    // Element offsets of this batch inside each tensor.
    const long_index_t a_offset =
        __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(batch_idx)));
    const long_index_t b0_offset =
        __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(batch_idx)));
    const long_index_t b1_offset =
        __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(batch_idx)));
    const long_index_t c_offset =
        __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(batch_idx)));

    // This variant has no D0/D1 tensors, so empty tuples are passed in their slots.
    GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
        arg.p_a_grid + a_offset,
        arg.p_b0_grid + b0_offset,
        Tuple<>{}, // D0 pointers
        arg.p_b1_grid + b1_offset,
        Tuple<>{}, // D1 pointers
        arg.p_c_grid + c_offset,
        p_shared,
        arg.a_grid_desc,
        arg.b0_grid_desc,
        Tuple<>{}, // D0 grid descriptors
        arg.b1_grid_desc,
        Tuple<>{}, // D1 grid descriptors
        arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
        arg.a_element_op,
        arg.b0_element_op,
        arg.acc_element_op,
        arg.b1_element_op,
        arg.c_element_op,
        arg.block_2_ctile_map);
#else
    // Unsupported device target: keep the symbol but do nothing.
    ignore = arg;
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
}
// Computes C = A * B0 * B1
// MN = MK * KL * LN
// ^^^^^^ (Acc0)
@@ -157,88 +101,47 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
// to LPerWmma (A.k.a Gemm0 NPerWmma).
static constexpr index_t NPerWmma = LPerWmma;
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
// Transform operator or just not use one at all.
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
Sequence<1, 1, 1, 1, 1>,
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
GemmSpec,
TensorSpecialization::Default, // ASpec
TensorSpecialization::Default, // B0Spec
TensorSpecialization::Default, // B1Spec
TensorSpecialization::Default>; // CSpec
// Builds the A grid descriptor: first the M x K view from the given lengths/strides,
// then the AK0 x M x AK1 view with AK1 as the innermost K split.
__host__ __device__ static auto
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
                    const std::array<index_t, 3>& a_g_m_k_strides_vec)
{
    const auto a_desc_m_k =
        Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec);
    return Transform::MakeAGridDescriptor_AK0_M_AK1(a_desc_m_k, Number<AK1>{});
}
// Builds the B0 grid descriptor: first the N x K view from the given lengths/strides,
// then the BK0 x N x BK1 view with BK1 as the innermost K split.
__host__ __device__ static auto
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
                     const std::array<index_t, 3>& b0_g_l_k_strides_vec)
{
    const auto b0_desc_n_k =
        Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec);
    return Transform::MakeB0GridDescriptor_BK0_N_BK1(b0_desc_n_k, Number<BK1>{});
}
// Builds the B1 grid descriptor: first the N x K view from the given lengths/strides,
// then the BK0 x N x BK1 view with L1 as the innermost split.
__host__ __device__ static auto
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
                     const std::array<index_t, 3>& b1_g_n_l_strides_vec)
{
    const auto b1_desc_n_k =
        Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec);
    return Transform::MakeB1GridDescriptor_BK0_N_BK1(b1_desc_n_k, Number<L1>{});
}
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
// Maps a batch index to the element offset of that batch inside each of the
// A, B0, B1 and C tensors, given one batch stride per tensor.
struct ComputeBasePtrOfStridedBatch
{
    ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
                                 index_t BatchStrideB0,
                                 index_t BatchStrideB1,
                                 index_t BatchStrideC)
        : BatchStrideA_(BatchStrideA),
          BatchStrideB0_(BatchStrideB0),
          BatchStrideB1_(BatchStrideB1),
          BatchStrideC_(BatchStrideC)
    {
    }

    // Offset of batch g_idx in A.
    __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
    {
        return BatchOffset(g_idx, BatchStrideA_);
    }

    // Offset of batch g_idx in B0.
    __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
    {
        return BatchOffset(g_idx, BatchStrideB0_);
    }

    // Offset of batch g_idx in B1.
    __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
    {
        return BatchOffset(g_idx, BatchStrideB1_);
    }

    // Offset of batch g_idx in C.
    __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
    {
        return BatchOffset(g_idx, BatchStrideC_);
    }

    private:
    // The stride is widened to long_index_t before multiplying so the product is
    // computed in the wide type.
    __host__ __device__ static constexpr long_index_t BatchOffset(index_t g_idx,
                                                                  index_t batch_stride)
    {
        return g_idx * static_cast<long_index_t>(batch_stride);
    }

    index_t BatchStrideA_;
    index_t BatchStrideB0_;
    index_t BatchStrideB1_;
    index_t BatchStrideC_;
};
using DeviceGemmGemmCommonBase =
DeviceGemmGemm_Wmma_CShuffleV3_Common<DeviceOp,
GemmSpec,
ALayout,
B0layout,
Tuple<>, // D0sLayout
B1Layout,
Tuple<>, // D1sLayout
CLayout,
BlockSize,
MPerBlock,
LPerBlock,
KPerBlock,
NPerBlock,
ADataType,
B0DataType,
B1DataType,
AccDataType,
CDataType,
Tuple<>, // D0sDataType
Tuple<>, // D1sDataType
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
AK1,
BK1,
L1,
MPerWmma,
LPerWmma,
BlkGemmPipelineVer,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector
CShuffleBlockTransferScalarPerVector_NPerBlock,
false>; // IsMultiD
// GridwiseOp
using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3<
@@ -260,12 +163,12 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
// InMemory Data Descriptor
AGridDesc,
B0GridDesc,
typename DeviceGemmGemmCommonBase::AGridDesc,
typename DeviceGemmGemmCommonBase::B0GridDesc,
Tuple<>, // Ds0GridDesc
B1GridDesc,
typename DeviceGemmGemmCommonBase::B1GridDesc,
Tuple<>, // Ds1GridDesc
CGridDesc_M_N,
typename DeviceGemmGemmCommonBase::CGridDesc_M_N,
// Tiling Family
MPerBlock,
LPerBlock,
@@ -312,339 +215,67 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Transform::matrix_padder.PadN,
DeviceGemmGemmCommonBase::GridDescriptorCreator::Transform::matrix_padder.PadN,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
struct RawArg : public BaseArgument
using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg<
DeviceOp,
GemmSpec,
ALayout,
B0layout,
Tuple<>, // D0sLayout
B1Layout,
Tuple<>, // D1sLayout
CLayout,
BlockSize,
MPerBlock,
LPerBlock,
KPerBlock,
NPerBlock,
ADataType,
B0DataType,
B1DataType,
AccDataType,
CDataType,
Tuple<>, // D0sDataType,
Tuple<>, // D1sDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
AK1,
BK1,
L1,
MPerWmma,
LPerWmma,
BlkGemmPipelineVer,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector
CShuffleBlockTransferScalarPerVector_NPerBlock,
false>; // IsMultiD
// Invoker
using Invoker = typename DeviceGemmGemmCommon::Invoker;
// Argument
using Argument = typename DeviceGemmGemmCommon::Argument;
static bool IsSupportedArgument(const Argument& arg)
{
using arr3 = std::array<ck::index_t, 3>;
RawArg(const ADataType* p_a_grid_,
const B0DataType* p_b0_grid_,
const B1DataType* p_b1_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t O_,
index_t Batch,
index_t StrideA,
index_t StrideB0,
index_t StrideB1,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB0,
index_t BatchStrideB1,
index_t BatchStrideC,
AElementwiseOperation a_element_op_,
B0ElementwiseOperation b0_element_op_,
AccElementwiseOperation acc_element_op_,
B1ElementwiseOperation b1_element_op_,
CElementwiseOperation c_element_op_)
: p_a_grid{p_a_grid_},
p_b0_grid{p_b0_grid_},
p_b1_grid{p_b1_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
O{O_},
batch_count{Batch},
a_element_op{a_element_op_},
b0_element_op{b0_element_op_},
acc_element_op{acc_element_op_},
b1_element_op{b1_element_op_},
c_element_op{c_element_op_},
compute_base_ptr_of_batch{BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC}
{
a_g_m_k_lengths = arr3{batch_count, M, K};
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
b0_g_n_k_lengths = arr3{batch_count, N, K};
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
b1_g_o_n_lengths = arr3{batch_count, O, N};
b1_g_o_n_strides =
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
c_g_m_o_lengths = arr3{batch_count, M, O};
c_g_m_o_strides = arr3{BatchStrideC, StrideC, 1}; // C layout [batch_count, M, O]
a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides);
b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides);
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides);
c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2ETileMap(c_grid_desc_m_n, 1, 1);
}
// Pointers
const ADataType* p_a_grid;
const B0DataType* p_b0_grid;
const B1DataType* p_b1_grid;
CDataType* p_c_grid;
// Raw Problem Size
index_t M;
index_t N;
index_t K;
index_t O;
index_t batch_count;
arr3 a_g_m_k_lengths;
arr3 a_g_m_k_strides;
arr3 b0_g_n_k_lengths;
arr3 b0_g_n_k_strides;
arr3 b1_g_o_n_lengths;
arr3 b1_g_o_n_strides;
arr3 c_g_m_o_lengths;
arr3 c_g_m_o_strides;
AElementwiseOperation a_element_op;
B0ElementwiseOperation b0_element_op;
AccElementwiseOperation acc_element_op;
B1ElementwiseOperation b1_element_op;
CElementwiseOperation c_element_op;
// Grid descriptors and other mem calculators
AGridDesc a_grid_desc;
B0GridDesc b0_grid_desc;
B1GridDesc b1_grid_desc;
CGridDesc_M_N c_grid_desc_m_n;
typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseOp::DefaultBlock2ETileMap block_2_ctile_map;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
};
static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
{
// Print lambda with env check and printf() style formatting.
const char* curFunc = __func__;
auto print = [&curFunc](const char* format, ...) -> void {
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
va_list args;
va_start(args, format);
std::vfprintf(stdout, format, args);
va_end(args);
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
}
};
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
print("DeviceOp: Arch err\n");
return false;
}
if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
{
if(ck::is_gfx11_supported())
{
print("DeviceOp: gfx 11 does not support fp8\n");
return false;
}
}
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
print("DeviceOp: Acc0 Type err\n");
return false;
}
if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
{
print("DeviceOp: A layout must be Row\n");
return false;
}
if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
{
print("DeviceOp: B layout must be Column\n");
return false;
}
if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
{
print("DeviceOp: B1 layout must be Column or Row\n");
return false;
}
if constexpr(!(is_same_v<CLayout, tensor_layout::gemm::RowMajor>))
{
print("DeviceOp: C layout must be Row\n");
return false;
}
// Other padding modes have not been tested and do not get checked individually.
if constexpr(GemmSpec != GemmSpecialization::Default &&
GemmSpec != GemmSpecialization::MNKOPadding)
{
print("Padding mode must be default or MNKO\n");
return false;
}
// Per wmma dimensions not equal to 16 are very untested.
if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
{
print("M, L, N per Wmma must be 16\n");
return false;
}
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b0_grid_desc,
Tuple<>{},
arg.b1_grid_desc,
Tuple<>{},
arg.c_grid_desc_m_n,
arg.block_2_ctile_map))
{
return false;
}
// Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
const auto c_extent_lowest = arg.O;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
print("DeviceOp: Data Transfer Vector scalar err\n");
return false;
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
const auto b0_stride_lowest =
B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
const auto c_stride_lowest = arg.c_g_m_o_strides[2];
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{
print("DeviceOp: Data Vectorize transfer err\n");
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
{
return false;
}
return true;
return DeviceGemmGemmCommon::IsSupportedArgument(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::RawArg;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
const index_t grid_size = arg.batch_count * M0 * N0;
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
constexpr TailNumber tn = tail_number;
const auto kernel =
kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp, GridwiseOp, has_loop, tn>;
return launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
};
bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K);
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
return 0.0f;
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Even>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Odd>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
return 0.0f;
}
}
else
{
printf("Invalid pipeline version!\n");
return 0.0f;
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b0,
@@ -669,28 +300,39 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<RawArg>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0),
static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c),
M,
N,
K,
O,
Batch,
StrideA,
StrideB0,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op);
// Placeholder D0/D1 pointer and stride arrays required by the shared Argument constructor.
// DeviceGemmGemmCommonBase is instantiated with IsMultiD == false, so these arrays have
// NumD0Tensor / NumD1Tensor == 0 elements. All of them are value-initialized so that no
// array is ever copied in an indeterminate state.
std::array<const void*, DeviceGemmGemmCommonBase::NumD0Tensor> p_d0_grid{};
std::array<const void*, DeviceGemmGemmCommonBase::NumD1Tensor> p_d1_grid{};
std::array<index_t, DeviceGemmGemmCommonBase::NumD0Tensor> StrideD0s{}, BatchStrideD0s{};
std::array<index_t, DeviceGemmGemmCommonBase::NumD1Tensor> StrideD1s{}, BatchStrideD1s{};
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0),
p_d0_grid,
static_cast<const B1DataType*>(p_b1),
p_d1_grid,
static_cast<CDataType*>(p_c),
M,
N,
K,
O,
Batch,
StrideA,
StrideB0,
StrideD0s,
StrideB1,
StrideD1s,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideC,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op);
}
static auto MakeInvoker() { return Invoker{}; }

View File

@@ -0,0 +1,902 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdint>
#include <iostream>
#include <cstdarg>
#include <type_traits>
#include <utility>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/utility/scheduler_enum.hpp"
#include "ck/utility/integral_constant.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Kernel entry point shared by the batched GEMM+GEMM WMMA CShuffle V3 device operations.
// Each workgroup derives its batch index from its 1D block id, offsets the A/B0/B1/C(E1)
// base pointers (and, when IsMultiD is set, every D0/D1 pointer) by the per-batch offsets
// and forwards everything to GridwiseOp::Run.
// The body is compiled only for host passes and gfx11/gfx12 device passes.
template <typename DeviceOp,
          typename GridwiseOp,
          bool HasMainKBlockLoop,
          TailNumber TailNum,
          bool IsMultiD>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::Argument arg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
    __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];

    // The grid is split evenly between batches; find the batch this block belongs to.
    const index_t blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
    const index_t batch_idx =
        __builtin_amdgcn_readfirstlane(get_block_1d_id() / blocks_per_batch);

    // Element offsets of this batch inside each of the main tensors.
    const long_index_t a_offset =
        __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(batch_idx)));
    const long_index_t b0_offset =
        __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(batch_idx)));
    const long_index_t b1_offset =
        __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(batch_idx)));
    const long_index_t c_e1_offset =
        __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCE1BasePtr(batch_idx)));

    // D0/D1 pointer tuples: empty tuples unless this is the multi-D variant.
    auto [d0_ptrs, d1_ptrs] = [&]() {
        if constexpr(!IsMultiD)
        {
            return std::make_pair(Tuple<>{}, Tuple<>{});
        }
        else
        {
            // Writes base_ptrs(i) + batch offset into shifted_ptrs(i) for every tensor i
            // and hands the filled tuple back by value.
            auto shift_to_batch =
                [](auto tensor_count, auto offset_of, auto& base_ptrs, auto&& shifted_ptrs) {
                    static_for<0, decltype(tensor_count)::value, 1>{}([&](auto d) {
                        const long_index_t offset =
                            __builtin_amdgcn_readfirstlane(offset_of(d));
                        shifted_ptrs(d) = base_ptrs(d) + offset;
                    });
                    return std::move(shifted_ptrs);
                };

            auto d0_offset_of = [&arg, &batch_idx](auto d) {
                return arg.compute_base_ptr_of_batch.GetD0BasePtr(batch_idx, d);
            };
            auto d1_offset_of = [&arg, &batch_idx](auto d) {
                return arg.compute_base_ptr_of_batch.GetD1BasePtr(batch_idx, d);
            };

            auto shifted_d0s =
                shift_to_batch(ck::integral_constant<ck::index_t, DeviceOp::NumD0Tensor>{},
                               d0_offset_of,
                               arg.p_d0s_grid,
                               GridwiseOp::MakeD0sGridPointer());
            auto shifted_d1s =
                shift_to_batch(ck::integral_constant<ck::index_t, DeviceOp::NumD1Tensor>{},
                               d1_offset_of,
                               arg.p_d1s_grid,
                               GridwiseOp::MakeD1sGridPointer());

            return std::make_pair(shifted_d0s, shifted_d1s);
        }
    }();

    GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
        arg.p_a_grid + a_offset,
        arg.p_b0_grid + b0_offset,
        d0_ptrs,
        arg.p_b1_grid + b1_offset,
        d1_ptrs,
        arg.p_c_e1_grid + c_e1_offset,
        p_shared,
        arg.a_grid_desc,
        arg.b0_grid_desc,
        arg.d0s_grid_desc,
        arg.b1_grid_desc,
        arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
        arg.c_e1_grid_desc_mblock_mperblock_nblock_nperblock,
        arg.a_element_op,
        arg.b0_element_op,
        arg.acc_element_op,
        arg.b1_element_op,
        arg.cde1_element_op,
        arg.block_2_etile_map);
#else
    // Unsupported device target: keep the symbol but do nothing.
    ignore = arg;
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
}
// Building blocks shared by the WMMA CShuffle V3 GEMM+GEMM device operations:
//  - NumD0Tensor / NumD1Tensor: number of auxiliary D0 / D1 tensors (0 unless IsMultiD),
//  - GridDescriptorCreator: factories for the A/B0/B1/D0/D1/E1 grid descriptors,
//  - the descriptor type aliases derived from those factories,
//  - ComputeBasePtrOfStridedBatch: batch index -> per-tensor element offset.
// Several template parameters are not referenced in this struct itself; they are kept so
// the parameter list stays identical for every user of the common helpers.
template <typename DeviceOp,
          GemmSpecialization GemmSpec,
          typename ALayout,
          typename B0layout,
          typename D0sLayout,
          typename B1Layout,
          typename D1sLayout,
          typename CE1Layout,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t LPerBlock, // Gemm0NPerBlock
          ck::index_t KPerBlock, // Gemm0KPerBlock
          ck::index_t NPerBlock, // Gemm1NPerBlock
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename AccDataType,
          typename CE1DataType,
          typename D0sDataType,
          typename D1sDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename AccElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CDE1ElementwiseOperation,
          ck::index_t AK1,
          ck::index_t BK1,
          ck::index_t L1,       // B1K1
          ck::index_t MPerWmma, // Gemm0/1 MPerWmma
          ck::index_t LPerWmma, // Gemm0/1 NPerWmma
          BlockGemmPipelineVersion BlkGemmPipelineVer,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t B0BlockTransferSrcVectorDim,
          ck::index_t B0BlockTransferSrcScalarPerVector,
          ck::index_t B1BlockTransferSrcVectorDim,
          ck::index_t B1BlockTransferSrcScalarPerVector,
          ck::index_t CDE0BlockTransferSrcScalarPerVector,
          ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          bool IsMultiD = false>
struct DeviceGemmGemm_Wmma_CShuffleV3_Common
{
    // Number of D0 tensors: taken from DeviceOp for the multi-D variant, otherwise 0.
    // The explicit return type and the else-branch keep return-type deduction independent
    // of the exact integral type of DeviceOp::NumD0Tensor, and make sure DeviceOp is only
    // inspected when IsMultiD is set.
    static constexpr ck::index_t NumD0Tensor = []() -> ck::index_t {
        if constexpr(IsMultiD)
        {
            return DeviceOp::NumD0Tensor;
        }
        else
        {
            return 0;
        }
    }();

    // Number of D1 tensors: taken from DeviceOp for the multi-D variant, otherwise 0.
    static constexpr ck::index_t NumD1Tensor = []() -> ck::index_t {
        if constexpr(IsMultiD)
        {
            return DeviceOp::NumD1Tensor;
        }
        else
        {
            return 0;
        }
    }();

    // Factories turning {lengths, strides} triples into the grid descriptors consumed by
    // the gridwise kernel. NOTE(review): the triples appear to be ordered
    // [batch, rows, cols] judging by the parameter names - confirm against callers.
    struct GridDescriptorCreator
    {
        // TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a
        // simpler Transform operator or just not use one at all.
        using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
            Sequence<1, 1, 1, 1, 1>,
            Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
            GemmSpec,
            TensorSpecialization::Default,  // ASpec
            TensorSpecialization::Default,  // B0Spec
            TensorSpecialization::Default,  // B1Spec
            TensorSpecialization::Default>; // CSpec

        // A: M x K view, then split into AK0 x M x AK1.
        __host__ __device__ static auto
        MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
                            const std::array<index_t, 3>& a_g_m_k_strides_vec)
        {
            return Transform::MakeAGridDescriptor_AK0_M_AK1(
                Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
                Number<AK1>{});
        }

        // B0: N x K view, then split into BK0 x N x BK1.
        __host__ __device__ static auto
        MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
                             const std::array<index_t, 3>& b0_g_l_k_strides_vec)
        {
            return Transform::MakeB0GridDescriptor_BK0_N_BK1(
                Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
                Number<BK1>{});
        }

        // B1: N x K view, then split into BK0 x N x BK1 using L1 as the inner split.
        __host__ __device__ static auto
        MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
                             const std::array<index_t, 3>& b1_g_n_l_strides_vec)
        {
            return Transform::MakeB1GridDescriptor_BK0_N_BK1(
                Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
                Number<L1>{});
        }

        // Single D0 tensor: described with the same M x N descriptor factory as C.
        __host__ __device__ static auto
        MakeD0GridDescriptor(const std::array<index_t, 3>& d0_g_m_n_lengths_vec,
                             const std::array<index_t, 3>& d0_g_m_n_strides_vec)
        {
            return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec);
        }

        // Tuple with one D0 descriptor per D0 tensor (empty when NumD0Tensor == 0).
        __host__ __device__ static auto MakeD0sGridDescriptor(
            const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_lengths_vec,
            const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_strides_vec)
        {
            return generate_tuple(
                [&](auto i) {
                    return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]);
                },
                Number<NumD0Tensor>{});
        }

        // Tuple with one descriptor per D1 tensor (empty when NumD1Tensor == 0); each D1
        // tensor is described exactly like E1.
        __host__ __device__ static auto MakeD1sGridDescriptor(
            const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_lengths_vec,
            const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_strides_vec)
        {
            return generate_tuple(
                [&](auto i) {
                    return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]);
                },
                Number<NumD1Tensor>{});
        }

        // E1 (output): M x N descriptor.
        __host__ __device__ static auto
        MakeE1GridDescriptor(const std::array<index_t, 3>& e1_g_m_n_lengths_vec,
                             const std::array<index_t, 3>& e1_g_m_n_strides_vec)
        {
            return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec);
        }
    };

    // Descriptor types, deduced from the factories above.
    using AGridDesc  = decltype(GridDescriptorCreator::MakeAGridDescriptor({}, {}));
    using B0GridDesc = decltype(GridDescriptorCreator::MakeB0GridDescriptor({}, {}));
    using D0sGridDesc =
        remove_cvref_t<decltype(GridDescriptorCreator::MakeD0sGridDescriptor({}, {}))>;
    using B1GridDesc = decltype(GridDescriptorCreator::MakeB1GridDescriptor({}, {}));
    using D1sGridDesc =
        remove_cvref_t<decltype(GridDescriptorCreator::MakeD1sGridDescriptor({}, {}))>;
    using E1GridDesc = decltype(GridDescriptorCreator::MakeE1GridDescriptor({}, {}));
    using CGridDesc_M_N =
        decltype(GridDescriptorCreator::Transform::MakeCGridDescriptor_M_N({}, {}));

    // Maps a batch index to the element offset of that batch inside each tensor, given one
    // batch stride per tensor.
    struct ComputeBasePtrOfStridedBatch
    {
        // Variant without D tensors. The D0/D1 stride arrays are value-initialized so that
        // GetD0BasePtr/GetD1BasePtr never read indeterminate values.
        ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
                                     index_t BatchStrideB0,
                                     index_t BatchStrideB1,
                                     index_t BatchStrideC)
            : BatchStrideA_(BatchStrideA),
              BatchStrideB0_(BatchStrideB0),
              BatchStrideD0s_{},
              BatchStrideB1_(BatchStrideB1),
              BatchStrideD1s_{},
              BatchStrideC_E1_(BatchStrideC)
        {
        }

        // Variant with one batch stride per D0 and per D1 tensor.
        ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
                                     index_t BatchStrideB0,
                                     std::array<index_t, NumD0Tensor> BatchStrideD0s,
                                     index_t BatchStrideB1,
                                     std::array<index_t, NumD1Tensor> BatchStrideD1s,
                                     index_t BatchStrideE1)
            : BatchStrideA_(BatchStrideA0),
              BatchStrideB0_(BatchStrideB0),
              BatchStrideD0s_(BatchStrideD0s),
              BatchStrideB1_(BatchStrideB1),
              BatchStrideD1s_(BatchStrideD1s),
              BatchStrideC_E1_(BatchStrideE1)
        {
        }

        // Each getter widens the stride to long_index_t before multiplying by the batch
        // index, so the product is computed in the wide type.
        __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideA_);
        }

        __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideB0_);
        }

        __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideB1_);
        }

        __host__ __device__ constexpr long_index_t GetCE1BasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideC_E1_);
        }

        // Offset of batch g_idx in the d0_idx-th D0 tensor.
        template <index_t I>
        __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
                                                                Number<I> d0_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d0_idx]);
        }

        // Offset of batch g_idx in the d1_idx-th D1 tensor.
        template <index_t I>
        __host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx,
                                                                Number<I> d1_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
        }

        private:
        // Declaration order matters: it is the order used by the constructor init lists.
        index_t BatchStrideA_;
        index_t BatchStrideB0_;
        std::array<index_t, NumD0Tensor> BatchStrideD0s_;
        index_t BatchStrideB1_;
        std::array<index_t, NumD1Tensor> BatchStrideD1s_;
        index_t BatchStrideC_E1_;
    };
};
template <typename DeviceOp,
GemmSpecialization GemmSpec,
typename ALayout,
typename B0layout,
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename CE1Layout,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t LPerBlock, // Gemm0NPerBlock
ck::index_t KPerBlock, // Gemm0KPerBlock
ck::index_t NPerBlock, // Gemm1NPerBlock
typename ADataType,
typename B0DataType,
typename B1DataType,
typename AccDataType,
typename CE1DataType,
typename D0sDataType,
typename D1sDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CDE1ElementwiseOperation,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t L1, // B1K1
ck::index_t MPerWmma, // Gemm0/1 MPerWmma
ck::index_t LPerWmma, // Gemm0/1 NPerWmma
BlockGemmPipelineVersion BlkGemmPipelineVer,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t B0BlockTransferSrcVectorDim,
ck::index_t B0BlockTransferSrcScalarPerVector,
ck::index_t B1BlockTransferSrcVectorDim,
ck::index_t B1BlockTransferSrcScalarPerVector,
ck::index_t CDE0BlockTransferSrcScalarPerVector,
ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool IsMultiD = false>
struct DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg
{
using GridwiseGemm = typename DeviceOp::GridwiseOp;
using Common =
DeviceGemmGemm_Wmma_CShuffleV3_Common<DeviceOp,
GemmSpec,
ALayout,
B0layout,
D0sLayout,
B1Layout,
D1sLayout,
CE1Layout,
BlockSize,
MPerBlock,
LPerBlock,
KPerBlock,
NPerBlock,
ADataType,
B0DataType,
B1DataType,
AccDataType,
CE1DataType,
D0sDataType,
D1sDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation,
AK1,
BK1,
L1,
MPerWmma,
LPerWmma,
BlkGemmPipelineVer,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
CDE0BlockTransferSrcScalarPerVector,
CShuffleBlockTransferScalarPerVector_NPerBlock,
IsMultiD>;
static constexpr auto NumD0Tensor = Common::NumD0Tensor;
static constexpr auto NumD1Tensor = Common::NumD1Tensor;
// Host-side argument bundle for the batched GEMM + GEMM kernel.
// The constructor stores the raw pointers, problem sizes and element-wise
// operators, derives the [batch, row, col] length/stride triples of every
// tensor, and builds the grid descriptors and tile map consumed by the kernel.
// D0/D1 related state is only populated when IsMultiD is true.
struct Argument : public BaseArgument
{
using arr3 = std::array<ck::index_t, 3>;
// Parameters:
//   p_*_grid_                   device pointers of A, B0, D0s, B1, D1s and the output E1
//   M_, N_, K_, O_              problem sizes: A is [M, K], B0 is [N, K], B1 is [O, N]
//                               and the output is [M, O] (see the layout comments below)
//   Batch                       number of batches
//   Stride*                     row strides of the individual tensors
//   BatchStride*                distance between consecutive batches of each tensor
//   *_element_op_               element-wise operators, stored as given
Argument(const ADataType* p_a_grid_,
const B0DataType* p_b0_grid_,
std::array<const void*, NumD0Tensor> p_d0s_grid_,
const B1DataType* p_b1_grid_,
std::array<const void*, NumD1Tensor> p_d1s_grid_,
CE1DataType* p_e1_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t O_,
index_t Batch,
index_t StrideA,
index_t StrideB0,
std::array<index_t, NumD0Tensor> StrideD0s,
index_t StrideB1,
std::array<index_t, NumD1Tensor> StrideD1s,
index_t StrideE1,
index_t BatchStrideA,
index_t BatchStrideB0,
std::array<index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1,
AElementwiseOperation a_element_op_,
B0ElementwiseOperation b0_element_op_,
AccElementwiseOperation acc_element_op_,
B1ElementwiseOperation b1_element_op_,
CDE1ElementwiseOperation cde1_element_op_)
: p_a_grid{p_a_grid_},
p_b0_grid{p_b0_grid_},
// D0/D1 pointer tuples start empty; they are filled in the body when IsMultiD.
p_d0s_grid{},
p_b1_grid{p_b1_grid_},
p_d1s_grid{},
p_c_e1_grid{p_e1_grid_},
M{M_},
N{N_},
K{K_},
O{O_},
batch_count{Batch},
a_element_op{a_element_op_},
b0_element_op{b0_element_op_},
acc_element_op{acc_element_op_},
b1_element_op{b1_element_op_},
cde1_element_op{cde1_element_op_},
compute_base_ptr_of_batch{BatchStrideA,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideE1}
{
// Lengths/strides are always expressed with a unit stride on the last
// dimension, except for a row-major B1 (see below).
a_g_m_k_lengths = arr3{batch_count, M, K};
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
b0_g_n_k_lengths = arr3{batch_count, N, K};
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
// B1 is always described as [batch_count, O, N]; for a row-major B1 the
// unit stride is therefore placed on the O dimension.
b1_g_o_n_lengths = arr3{batch_count, O, N};
b1_g_o_n_strides =
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
e1_g_m_o_lengths = arr3{batch_count, M, O};
e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O]
// Grid descriptors of the operands that exist in both variants.
a_grid_desc = Common::GridDescriptorCreator::MakeAGridDescriptor(a_g_m_k_lengths,
a_g_m_k_strides);
b0_grid_desc = Common::GridDescriptorCreator::MakeB0GridDescriptor(b0_g_n_k_lengths,
b0_g_n_k_strides);
b1_grid_desc = Common::GridDescriptorCreator::MakeB1GridDescriptor(b1_g_o_n_lengths,
b1_g_o_n_strides);
c_e1_grid_desc_m_n = Common::GridDescriptorCreator::MakeE1GridDescriptor(
e1_g_m_o_lengths, e1_g_m_o_strides);
c_e1_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_e1_grid_desc_m_n);
// NOTE(review): the meaning of the two trailing `1` arguments is defined by
// GridwiseGemm::MakeDefaultBlock2ETileMap and is not visible here.
block_2_etile_map = GridwiseGemm::MakeDefaultBlock2ETileMap(c_e1_grid_desc_m_n, 1, 1);
// Auxiliary D0/D1 tensors are only set up for the multi-D variant.
if constexpr(IsMultiD)
{
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
// D0s layout [batch_count, M, N]
d0s_g_m_n_lengths[i] = arr3{batch_count, M, N};
d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1};
// D0 pointer
p_d0s_grid(i) = static_cast<const D0DataType*>(p_d0s_grid_[i]);
});
// D0 desc
d0s_grid_desc = Common::GridDescriptorCreator::MakeD0sGridDescriptor(
d0s_g_m_n_lengths, d0s_g_m_n_strides);
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
// D1s layout [batch_count, M, O]
d1s_g_m_o_lengths[i] = arr3{batch_count, M, O};
d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1};
// D1 pointer
p_d1s_grid(i) = static_cast<const D1DataType*>(p_d1s_grid_[i]);
});
// D1 desc
d1s_grid_desc = Common::GridDescriptorCreator::MakeD1sGridDescriptor(
d1s_g_m_o_lengths, d1s_g_m_o_strides);
d1s_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d1s_grid_desc);
}
}
// Pointers
const ADataType* p_a_grid;
const B0DataType* p_b0_grid;
typename GridwiseGemm::D0sGridPointer p_d0s_grid;
const B1DataType* p_b1_grid;
typename GridwiseGemm::D1sGridPointer p_d1s_grid;
CE1DataType* p_c_e1_grid;
// Raw Problem Size
index_t M;
index_t N;
index_t K;
index_t O;
index_t batch_count;
// Per-tensor [batch, row, col] lengths and strides derived in the constructor.
arr3 a_g_m_k_lengths;
arr3 a_g_m_k_strides;
arr3 b0_g_n_k_lengths;
arr3 b0_g_n_k_strides;
std::array<arr3, NumD0Tensor> d0s_g_m_n_lengths;
std::array<arr3, NumD0Tensor> d0s_g_m_n_strides;
arr3 b1_g_o_n_lengths;
arr3 b1_g_o_n_strides;
std::array<arr3, NumD1Tensor> d1s_g_m_o_lengths;
std::array<arr3, NumD1Tensor> d1s_g_m_o_strides;
arr3 e1_g_m_o_lengths;
arr3 e1_g_m_o_strides;
// Element-wise operators, copied from the constructor arguments.
AElementwiseOperation a_element_op;
B0ElementwiseOperation b0_element_op;
AccElementwiseOperation acc_element_op;
B1ElementwiseOperation b1_element_op;
CDE1ElementwiseOperation cde1_element_op;
// Grid descriptors and other mem calculators
typename Common::AGridDesc a_grid_desc;
typename Common::B0GridDesc b0_grid_desc;
// Collapses to an empty Tuple<> when there are no D0 tensors to describe.
std::conditional_t<IsMultiD, typename Common::D0sGridDesc, Tuple<>> d0s_grid_desc;
typename Common::B1GridDesc b1_grid_desc;
typename Common::D1sGridDesc d1s_grid_desc;
std::conditional_t<
IsMultiD,
typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
Tuple<>>
d1s_grid_desc_mblock_mperblock_nblock_nperblock;
std::conditional_t<IsMultiD, typename Common::E1GridDesc, typename Common::CGridDesc_M_N>
c_e1_grid_desc_m_n;
typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_e1_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map;
// Batch index -> per-tensor offset calculator.
typename Common::ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
};
/// @brief Helper structure responsible for kernel invocation.
///
/// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU
/// kernel function. It usually determines the launched grid size prepares kernel
/// arguments as well as perform specific kernel configuration selection based on
/// runtime arguments.
///
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
const index_t grid_size = arg.batch_count * M0 * N0;
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
constexpr TailNumber tail_num = decltype(tail_number)::value;
const auto kernel = kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp,
GridwiseGemm,
has_loop,
tail_num,
IsMultiD>;
return launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
};
bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(arg.K);
TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.K);
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
return 0.0f;
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Even>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Odd>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
return 0.0f;
}
}
else
{
printf("Invalid pipeline version!\n");
return 0.0f;
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
// Compile-time validity hook for the template configuration.
// Currently a placeholder that accepts every configuration (always true).
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
// Returns true when each of the first NumDTensor layouts in the DsLayout tuple
// is exactly RefLayout (trivially true for NumDTensor == 0).
template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
static constexpr bool CheckDLayout()
{
    bool all_match = true;
    // Visit every tuple element; a single mismatch clears the flag.
    static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
        using DLayout = remove_cvref_t<tuple_element_t<d_idx.value, DsLayout>>;
        if constexpr(!is_same_v<RefLayout, DLayout>)
        {
            all_match = false;
        }
    });
    return all_match;
}
static bool IsSupportedArgument(const Argument& arg)
{
// Print lambda with env check and printf() style formmating.
const char* curFunc = __func__;
auto print = [&curFunc](const char* format, ...) -> void {
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
va_list args;
va_start(args, format);
std::vfprintf(stdout, format, args);
va_end(args);
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
}
};
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
print("DeviceOp: Arch err\n");
return false;
}
if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
{
if(ck::is_gfx11_supported())
{
print("DeviceOp: gfx 11 does not support fp8\n");
return false;
}
}
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
print("DeviceOp: Acc0 Type err\n");
return false;
}
if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
{
print("DeviceOp: A layout must be Row\n");
return false;
}
if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
{
print("DeviceOp: B1 layout must be Column or Row\n");
return false;
}
if constexpr(!(is_same_v<CE1Layout, tensor_layout::gemm::RowMajor>))
{
print("DeviceOp: C layout must be Row\n");
return false;
}
// Other padding modes have not been tested and do not get checked individually.
if constexpr(GemmSpec != GemmSpecialization::Default &&
GemmSpec != GemmSpecialization::MNKOPadding)
{
print("Padding mode must be default or MNKO\n");
return false;
}
// Per wmma dimensions not equal to 16 are very untested.
if constexpr(MPerWmma != 16 || LPerWmma != 16 || DeviceOp::NPerWmma != 16)
{
print("M, L, N per Wmma must be 16\n");
return false;
}
if constexpr(IsMultiD)
{
if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
{
print("DeviceOp: B0 layout must be Column\n");
return false;
}
if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>()))
{
print("DeviceOp: All D0s layout must be Row\n");
return false;
}
if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>()))
{
print("DeviceOp: All D1s layout must be Row\n");
return false;
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc,
arg.b0_grid_desc,
arg.d0s_grid_desc,
arg.b1_grid_desc,
arg.d1s_grid_desc,
arg.c_e1_grid_desc_m_n,
arg.block_2_etile_map))
{
return false;
}
// Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
const auto cde1_extent_lowest = arg.O;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
print("DeviceOp: Data Transfer Vector scalar err\n");
return false;
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2
? arg.b0_g_n_k_strides[2]
: arg.b0_g_n_k_strides[1];
const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
? arg.b1_g_o_n_strides[2]
: arg.b1_g_o_n_strides[1];
const auto e1_stride_lowest = arg.e1_g_m_o_strides[2];
// NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major
// and the lowest dimension stride is hardcoded to 1
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
e1_stride_lowest == 1))
{
print("DeviceOp: Data Vectorize transfer err\n");
return false;
}
}
else
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc,
arg.b0_grid_desc,
Tuple<>{},
arg.b1_grid_desc,
Tuple<>{},
arg.c_e1_grid_desc_m_n,
arg.block_2_etile_map))
{
return false;
}
// Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
const auto c_extent_lowest = arg.O;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
print("DeviceOp: Data Transfer Vector scalar err\n");
return false;
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2
? arg.b0_g_n_k_strides[2]
: arg.b0_g_n_k_strides[1];
const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
? arg.b1_g_o_n_strides[2]
: arg.b1_g_o_n_strides[1];
const auto c_stride_lowest = arg.e1_g_m_o_strides[2];
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{
print("DeviceOp: Data Vectorize transfer err\n");
return false;
}
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
{
return false;
}
return true;
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -3,91 +3,20 @@
#pragma once
#include <iostream>
#include <sstream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// GPU entry point of the batched GEMM + multi-D + GEMM + multi-D WMMA op.
// Each workgroup derives its batch index from its 1-D block id, offsets every
// tensor pointer to that batch and forwards everything to GridwiseOp::Run.
// The body is only compiled for the host pass and for gfx11/gfx12 device
// targets; on other device targets the argument is merely marked as unused.
template <typename DeviceOp, typename GridwiseOp, bool HasMainKBlockLoop, TailNumber TailNum>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
// The grid is split evenly between batches: block id / blocks-per-batch
// gives the batch index handled by this workgroup.
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// Element offsets of this batch inside A, B0, B1 and E1.
const long_index_t a_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t e1_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetE1BasePtr(g_idx)));
// Per-batch pointer tuples for the D0 and D1 tensors.
auto p_d0s_grid = GridwiseOp::MakeD0sGridPointer();
auto p_d1s_grid = GridwiseOp::MakeD1sGridPointer();
static_for<0, DeviceOp::NumD0Tensor, 1>{}([&](auto In) {
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
p_d0s_grid(In) = arg.p_d0s_grid(In) + d0_batch_offset;
});
static_for<0, DeviceOp::NumD1Tensor, 1>{}([&](auto In) {
const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In)));
p_d1s_grid(In) = arg.p_d1s_grid(In) + d1_batch_offset;
});
// Run the gridwise op on the batch-local pointers.
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
arg.p_a_grid + a_batch_offset,
arg.p_b0_grid + b0_batch_offset,
p_d0s_grid,
arg.p_b1_grid + b1_batch_offset,
p_d1s_grid,
arg.p_e1_grid + e1_batch_offset,
p_shared,
arg.a_grid_desc,
arg.b0_grid_desc,
arg.d0s_grid_desc,
arg.b1_grid_desc,
arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
arg.e1_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op,
arg.b0_element_op,
arg.acc_element_op,
arg.b1_element_op,
arg.cde1_element_op,
arg.block_2_etile_map);
#else
ignore = arg;
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
}
// Computes:
// Acc = Acc_Op(A_Op(A) * B0_Op(B0), D0_0, D0_1, ...)
// E = CDE1_Op(Acc_Op(Acc0) * B1_Op(B1), D1_0, D1_1, ...)
@@ -184,151 +113,51 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
static constexpr auto I0 = Number<0>{};
// To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal
// to LPerWmma (A.k.a Gemm0 NPerWmma).
static constexpr index_t NPerWmma = LPerWmma;
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
// Transform operator or just not use one at all.
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
Sequence<1, 1, 1, 1, 1>,
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
GemmSpec,
TensorSpecialization::Default, // ASpec
TensorSpecialization::Default, // B0Spec
TensorSpecialization::Default, // B1Spec
TensorSpecialization::Default>; // CSpec
// Builds the A grid descriptor in two steps: first the M x K descriptor from
// the 3-element length/stride arrays, then its AK0 x M x AK1 form with
// AK1 as the compile-time innermost length.
__host__ __device__ static auto
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
                    const std::array<index_t, 3>& a_g_m_k_strides_vec)
{
    const auto a_grid_desc_m_k =
        Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec);
    return Transform::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k, Number<AK1>{});
}
// Builds the B0 grid descriptor in two steps: first the N x K descriptor from
// the 3-element length/stride arrays, then its BK0 x N x BK1 form with
// BK1 as the compile-time innermost length.
__host__ __device__ static auto
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
                     const std::array<index_t, 3>& b0_g_l_k_strides_vec)
{
    const auto b0_grid_desc_n_k =
        Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec);
    return Transform::MakeB0GridDescriptor_BK0_N_BK1(b0_grid_desc_n_k, Number<BK1>{});
}
// Builds the B1 grid descriptor in two steps: first the N x K descriptor from
// the 3-element length/stride arrays, then its BK0 x N x BK1 form with
// L1 as the compile-time innermost length.
__host__ __device__ static auto
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
                     const std::array<index_t, 3>& b1_g_n_l_strides_vec)
{
    const auto b1_grid_desc_n_k =
        Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec);
    return Transform::MakeB1GridDescriptor_BK0_N_BK1(b1_grid_desc_n_k, Number<L1>{});
}
// Builds the M x N grid descriptor of a single D0 tensor from its 3-element
// length/stride arrays by forwarding them to Transform::MakeCGridDescriptor_M_N.
__host__ __device__ static auto
MakeD0GridDescriptor(const std::array<index_t, 3>& d0_g_m_n_lengths_vec,
                     const std::array<index_t, 3>& d0_g_m_n_strides_vec)
{
    auto d0_grid_desc_m_n =
        Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec);
    return d0_grid_desc_m_n;
}
// Builds one grid descriptor per D0 tensor (via MakeD0GridDescriptor) and
// returns them as a tuple with NumD0Tensor entries.
__host__ __device__ static auto MakeD0sGridDescriptor(
    const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_lengths_vec,
    const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_strides_vec)
{
    // Descriptor factory for the tensor selected by the compile-time index.
    const auto make_d0_desc = [&](auto d0_idx) {
        return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[d0_idx],
                                    d0_g_m_n_strides_vec[d0_idx]);
    };
    return generate_tuple(make_d0_desc, Number<NumD0Tensor>{});
}
// Builds one grid descriptor per D1 tensor and returns them as a tuple with
// NumD1Tensor entries. Each D1 descriptor is created with the same factory as
// the E1 output (MakeE1GridDescriptor).
//
// The input arrays hold one length/stride triple per D1 tensor, so they are
// sized with NumD1Tensor (they were previously sized with NumD0Tensor, which
// only matched when both counts happened to be equal).
__host__ __device__ static auto MakeD1sGridDescriptor(
    const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_lengths_vec,
    const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_strides_vec)
{
    return generate_tuple(
        [&](auto i) {
            return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]);
        },
        Number<NumD1Tensor>{});
}
// Builds the grid descriptor of the E1 output tensor from its 3-element
// length/stride arrays by forwarding them to Transform::MakeCGridDescriptor_M_N.
__host__ __device__ static auto
MakeE1GridDescriptor(const std::array<index_t, 3>& e1_g_m_n_lengths_vec,
                     const std::array<index_t, 3>& e1_g_m_n_strides_vec)
{
    auto e1_grid_desc_m_n =
        Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec);
    return e1_grid_desc_m_n;
}
// Grid descriptor types, deduced from the return types of the Make* factories
// above when invoked (unevaluated) with empty `{}` length/stride arguments.
// The D0s/D1s aliases are tuples of per-tensor descriptors with cv/ref
// qualifiers stripped.
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
using D0sGridDesc = remove_cvref_t<decltype(MakeD0sGridDescriptor({}, {}))>;
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
using D1sGridDesc = remove_cvref_t<decltype(MakeD1sGridDescriptor({}, {}))>;
using E1GridDesc = decltype(MakeE1GridDescriptor({}, {}));
// Maps a batch index to the element offset of that batch inside each tensor.
// Every getter returns `g_idx * batch_stride`, computed in long_index_t so the
// product is not truncated to index_t.
struct ComputeBasePtrOfStridedBatch
{
    // Stores one batch stride per tensor (and per D0/D1 tensor).
    ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
                                 index_t BatchStrideB0,
                                 std::array<index_t, NumD0Tensor> BatchStrideD0s,
                                 index_t BatchStrideB1,
                                 std::array<index_t, NumD1Tensor> BatchStrideD1s,
                                 index_t BatchStrideE1)
        : BatchStrideA0_(BatchStrideA0),
          BatchStrideB0_(BatchStrideB0),
          BatchStrideD0s_(BatchStrideD0s),
          BatchStrideB1_(BatchStrideB1),
          BatchStrideD1s_(BatchStrideD1s),
          BatchStrideE1_(BatchStrideE1)
    {
    }

    // Offset of batch `g_idx` in the A tensor.
    __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideA0_);
    }

    // Offset of batch `g_idx` in the B0 tensor.
    __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideB0_);
    }

    // Offset of batch `g_idx` in the D0 tensor selected by `d0_idx`.
    template <index_t I>
    __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
                                                            Number<I> d0_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d0_idx]);
    }

    // Offset of batch `g_idx` in the B1 tensor.
    __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideB1_);
    }

    // Offset of batch `g_idx` in the E1 output tensor.
    __host__ __device__ constexpr long_index_t GetE1BasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideE1_);
    }

    // Offset of batch `g_idx` in the D1 tensor selected by `d1_idx`.
    // Returns long_index_t explicitly, like every other getter.
    template <index_t I>
    __host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx,
                                                            Number<I> d1_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
    }

    private:
    // NOTE: the declaration order below is the member initialization order;
    // keep the constructor's initializer list in the same order.
    index_t BatchStrideA0_;
    index_t BatchStrideB0_;
    std::array<index_t, NumD0Tensor> BatchStrideD0s_;
    index_t BatchStrideB1_;
    std::array<index_t, NumD1Tensor> BatchStrideD1s_;
    index_t BatchStrideE1_;
};
// Instantiation of the shared DeviceGemmGemm_Wmma_CShuffleV3_Common helper for
// this device op, with the trailing IsMultiD flag set to true. The grid
// descriptor types (AGridDesc, B0GridDesc, ...) are taken from this alias.
// Note that this op's E1Layout/E1DataType are passed in the layout and data
// type slots that precede BlockSize and follow AccDataType, respectively.
using DeviceGemmGemmCommonBase =
DeviceGemmGemm_Wmma_CShuffleV3_Common<DeviceOp,
GemmSpec,
ALayout,
B0layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
BlockSize,
MPerBlock,
LPerBlock,
KPerBlock,
NPerBlock,
ADataType,
B0DataType,
B1DataType,
AccDataType,
E1DataType,
D0sDataType,
D1sDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation,
AK1,
BK1,
L1,
MPerWmma,
LPerWmma,
BlkGemmPipelineVer,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
CDE0BlockTransferSrcScalarPerVector,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>; // IsMultiD
// GridwiseOp
using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3<
@@ -350,12 +179,12 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
CDE1ElementwiseOperation,
InMemoryDataOperationEnum::Set,
// InMemory Data Descriptor
AGridDesc,
B0GridDesc,
D0sGridDesc,
B1GridDesc,
D1sGridDesc,
E1GridDesc,
typename DeviceGemmGemmCommonBase::AGridDesc,
typename DeviceGemmGemmCommonBase::B0GridDesc,
typename DeviceGemmGemmCommonBase::D0sGridDesc,
typename DeviceGemmGemmCommonBase::B1GridDesc,
typename DeviceGemmGemmCommonBase::D1sGridDesc,
typename DeviceGemmGemmCommonBase::E1GridDesc,
// Tiling Family
MPerBlock,
LPerBlock,
@@ -402,430 +231,67 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Transform::matrix_padder.PadN,
DeviceGemmGemmCommonBase::GridDescriptorCreator::Transform::matrix_padder.PadN,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
struct RawArg : public BaseArgument
// Instantiation of the shared DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg
// helper for this device op, with the trailing IsMultiD flag set to true.
// This op's Invoker and Argument types are taken from it below.
using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg<
DeviceOp,
GemmSpec,
ALayout,
B0layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
BlockSize,
MPerBlock,
LPerBlock,
KPerBlock,
NPerBlock,
ADataType,
B0DataType,
B1DataType,
AccDataType,
E1DataType,
D0sDataType,
D1sDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation,
AK1,
BK1,
L1,
MPerWmma,
LPerWmma,
BlkGemmPipelineVer,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
CDE0BlockTransferSrcScalarPerVector,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>; // IsMultiD
// Invoker
using Invoker = typename DeviceGemmGemmCommon::Invoker;
// Argument
using Argument = typename DeviceGemmGemmCommon::Argument;
static bool IsSupportedArgument(const Argument& arg)
{
using arr3 = std::array<ck::index_t, 3>;
RawArg(const ADataType* p_a_grid_,
const B0DataType* p_b0_grid_,
std::array<const void*, NumD0Tensor> p_d0s_grid_,
const B1DataType* p_b1_grid_,
std::array<const void*, NumD1Tensor> p_d1s_grid_,
E1DataType* p_e1_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t O_,
index_t Batch,
index_t StrideA,
index_t StrideB0,
std::array<index_t, NumD0Tensor> StrideD0s,
index_t StrideB1,
std::array<index_t, NumD1Tensor> StrideD1s,
index_t StrideE1,
index_t BatchStrideA,
index_t BatchStrideB0,
std::array<index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1,
AElementwiseOperation a_element_op_,
B0ElementwiseOperation b0_element_op_,
AccElementwiseOperation acc_element_op_,
B1ElementwiseOperation b1_element_op_,
CDE1ElementwiseOperation cde1_element_op_)
: p_a_grid{p_a_grid_},
p_b0_grid{p_b0_grid_},
p_d0s_grid{},
p_b1_grid{p_b1_grid_},
p_d1s_grid{},
p_e1_grid{p_e1_grid_},
M{M_},
N{N_},
K{K_},
O{O_},
batch_count{Batch},
a_element_op{a_element_op_},
b0_element_op{b0_element_op_},
acc_element_op{acc_element_op_},
b1_element_op{b1_element_op_},
cde1_element_op{cde1_element_op_},
compute_base_ptr_of_batch{BatchStrideA,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideE1}
{
a_g_m_k_lengths = arr3{batch_count, M, K};
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
b0_g_n_k_lengths = arr3{batch_count, N, K};
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
b1_g_o_n_lengths = arr3{batch_count, O, N};
b1_g_o_n_strides =
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
e1_g_m_o_lengths = arr3{batch_count, M, O};
e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O]
a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides);
b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides);
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
e1_grid_desc_m_n = MakeE1GridDescriptor(e1_g_m_o_lengths, e1_g_m_o_strides);
e1_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e1_grid_desc_m_n);
block_2_etile_map = GridwiseOp::MakeDefaultBlock2ETileMap(e1_grid_desc_m_n, 1, 1);
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
// D0s layout [batch_count, M, N]
d0s_g_m_n_lengths[i] = arr3{batch_count, M, N};
d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1};
// D0 pointer
p_d0s_grid(i) = static_cast<const D0DataType*>(p_d0s_grid_[i]);
// D0 desc
d0s_grid_desc(i) = MakeD0GridDescriptor(d0s_g_m_n_lengths[i], d0s_g_m_n_strides[i]);
});
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
// D1s layout [batch_count, M, O]
d1s_g_m_o_lengths[i] = arr3{batch_count, M, O};
d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1};
// D1 pointer
p_d1s_grid(i) = static_cast<const D1DataType*>(p_d1s_grid_[i]);
// D1 desc
d1s_grid_desc(i) = MakeE1GridDescriptor(d1s_g_m_o_lengths[i], d1s_g_m_o_strides[i]);
});
d1s_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(d1s_grid_desc);
}
// Pointers
const ADataType* p_a_grid;
const B0DataType* p_b0_grid;
typename GridwiseOp::D0sGridPointer p_d0s_grid;
const B1DataType* p_b1_grid;
typename GridwiseOp::D1sGridPointer p_d1s_grid;
E1DataType* p_e1_grid;
// Raw Problem Size
index_t M;
index_t N;
index_t K;
index_t O;
index_t batch_count;
arr3 a_g_m_k_lengths;
arr3 a_g_m_k_strides;
arr3 b0_g_n_k_lengths;
arr3 b0_g_n_k_strides;
std::array<arr3, NumD0Tensor> d0s_g_m_n_lengths;
std::array<arr3, NumD0Tensor> d0s_g_m_n_strides;
arr3 b1_g_o_n_lengths;
arr3 b1_g_o_n_strides;
std::array<arr3, NumD1Tensor> d1s_g_m_o_lengths;
std::array<arr3, NumD1Tensor> d1s_g_m_o_strides;
arr3 e1_g_m_o_lengths;
arr3 e1_g_m_o_strides;
AElementwiseOperation a_element_op;
B0ElementwiseOperation b0_element_op;
AccElementwiseOperation acc_element_op;
B1ElementwiseOperation b1_element_op;
CDE1ElementwiseOperation cde1_element_op;
// Grid descriptors and other mem calculators
AGridDesc a_grid_desc;
B0GridDesc b0_grid_desc;
D0sGridDesc d0s_grid_desc;
B1GridDesc b1_grid_desc;
D1sGridDesc d1s_grid_desc;
typename GridwiseOp::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d1s_grid_desc_mblock_mperblock_nblock_nperblock;
E1GridDesc e1_grid_desc_m_n;
typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e1_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseOp::DefaultBlock2ETileMap block_2_etile_map;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
};
// check if DsLayout is supported
//
// Compares each of the first NumDTensor layouts in DsLayout against RefLayout.
//
// Template parameters:
//   RefLayout  - the layout every D tensor layout is compared with.
//   DsLayout   - tuple-like list of layouts, indexed through tuple_element_t.
//   NumDTensor - number of entries of DsLayout that are inspected.
//
// Returns true only when every inspected layout is the same type as RefLayout.
template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
static constexpr bool CheckDLayout()
{
bool valid = true;
// iterate over DLayout tuple
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
// if RefLayout and DLayout are same, keep valid true, otherwise false
valid = valid && is_same_v<RefLayout, DLayout>;
});
return valid;
// NOTE(review): the statement below follows an unconditional return and refers to `arg`,
// which is not declared in this function. It looks like residue of an interleaved diff;
// confirm before relying on it.
return DeviceGemmGemmCommon::IsSupportedArgument(arg);
}
// Decide whether this device op can run the problem described by `arg`.
//
// The checks are, in order: device architecture, fp8 availability, accumulator type,
// tensor layouts, GEMM padding specialization, WMMA tile sizes, the gridwise validity
// check, vector-width divisibility of the innermost extents, innermost strides, and
// K divisibility by AK1/BK1 when MNKO padding is not used.
//
// Returns true when every check passes. Every rejection path returns false and, when the
// CK_LOGGING environment switch is enabled, prints the reason.
static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
{
    // Print lambda with env check and printf() style formatting.
    const char* curFunc = __func__;
    auto print          = [&curFunc](const char* format, ...) -> void {
        if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
        {
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
            va_list args;
            va_start(args, format);
            std::vfprintf(stdout, format, args);
            va_end(args);
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
            std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
        }
    };

    if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
    {
        print("DeviceOp: Arch err\n");
        return false;
    }

    // 8-bit float inputs are rejected on gfx11.
    if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
                 std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
                 std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
    {
        if(ck::is_gfx11_supported())
        {
            print("DeviceOp: gfx 11 does not support fp8\n");
            return false;
        }
    }

    if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
    {
        print("DeviceOp: Acc0 Type err\n");
        return false;
    }

    if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
    {
        print("DeviceOp: A layout must be Row\n");
        return false;
    }

    if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
    {
        print("DeviceOp: B0 layout must be Column\n");
        return false;
    }

    if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>()))
    {
        print("DeviceOp: All D0s layout must be Row\n");
        return false;
    }

    if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
                   is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
    {
        print("DeviceOp: B1 layout must be Column or Row\n");
        return false;
    }

    if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>()))
    {
        print("DeviceOp: All D1s layout must be Row\n");
        return false;
    }

    if constexpr(!(is_same_v<E1Layout, tensor_layout::gemm::RowMajor>))
    {
        print("DeviceOp: C layout must be Row\n");
        return false;
    }

    // Other padding modes have not been tested and do not get checked individually.
    if constexpr(GemmSpec != GemmSpecialization::Default &&
                 GemmSpec != GemmSpecialization::MNKOPadding)
    {
        print("Padding mode must be default or MNKO\n");
        return false;
    }

    // Per wmma dimensions not equal to 16 are very untested.
    if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
    {
        print("M, L, N per Wmma must be 16\n");
        return false;
    }

    if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
                                  arg.b0_grid_desc,
                                  arg.d0s_grid_desc,
                                  arg.b1_grid_desc,
                                  arg.d1s_grid_desc,
                                  arg.e1_grid_desc_m_n,
                                  arg.block_2_etile_map))
    {
        // Previously a silent rejection; report it like every other failing check.
        print("DeviceOp: GridwiseOp::CheckValidity failed\n");
        return false;
    }

    // Check scalar per vector requirement
    const auto a_extent_lowest    = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
    const auto b0_extent_lowest   = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
    const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major
    const auto b1_extent_lowest   = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
    const auto cde1_extent_lowest = arg.O;

    if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
         b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
         cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 &&
         b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
         cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
    {
        print("DeviceOp: Data Transfer Vector scalar err\n");
        return false;
    }

    // Check vector load/store requirement
    const auto a_stride_lowest =
        ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
    const auto b0_stride_lowest =
        B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
    const auto b1_stride_lowest =
        B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
    const auto e1_stride_lowest = arg.e1_g_m_o_strides[2];

    // NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major
    // and the lowest dimension stride is hardcoded to 1
    // NOTE(review): the condition below only rejects when none of the four strides is 1.
    // If e1_g_m_o_strides[2] is always the literal 1 written by the argument constructor,
    // this check can never fail. Confirm whether `&&` (or a per-tensor
    // "stride == 1 or scalar-per-vector == 1" test) was intended. Left unchanged here
    // because tightening it would change which arguments are accepted.
    if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
         e1_stride_lowest == 1))
    {
        print("DeviceOp: Data Vectorize transfer err\n");
        return false;
    }

    if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
    {
        // Previously a silent rejection; report it like every other failing check.
        print("DeviceOp: K must be a multiple of AK1 and BK1 unless MNKOPadding is used\n");
        return false;
    }

    return true;
}
// polymorphic
// Type-erased entry point: downcasts p_arg to RawArg and forwards to the static overload.
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so p_arg is
// presumably required to point to a RawArg; confirm against callers.
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
// NOTE(review): the statement below follows an unconditional return, so it cannot execute
// as written. It looks like residue of an interleaved diff; confirm which of the two
// returns is the intended one.
return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::RawArg;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
const index_t grid_size = arg.batch_count * M0 * N0;
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
constexpr TailNumber tn = tail_number;
const auto kernel =
kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3<DeviceOp,
GridwiseOp,
has_loop,
tn>;
return launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
};
bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K);
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
return 0.0f;
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Even>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Odd>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
return 0.0f;
}
}
else
{
printf("Invalid pipeline version!\n");
return 0.0f;
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static auto MakeArgument(const ADataType* p_a0,
const B0DataType* p_b0,
std::array<const void*, NumD0Tensor> p_d0s,
@@ -855,20 +321,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation cde1_element_op)
{
return RawArg{p_a0, p_b0,
p_d0s, p_b1,
p_d1s, p_e1,
MRaw, NRaw,
KRaw, Gemm1NRaw,
Batch, StrideA0,
StrideB0, StrideD0s,
StrideB1, StrideD1s,
StrideE1, BatchStrideA0,
BatchStrideB0, BatchStrideD0s,
BatchStrideB1, BatchStrideD1s,
BatchStrideE1, a0_element_op,
b0_element_op, cde0_element_op,
b1_element_op, cde1_element_op};
return Argument{p_a0, p_b0,
p_d0s, p_b1,
p_d1s, p_e1,
MRaw, NRaw,
KRaw, Gemm1NRaw,
Batch, StrideA0,
StrideB0, StrideD0s,
StrideB1, StrideD1s,
StrideE1, BatchStrideA0,
BatchStrideB0, BatchStrideD0s,
BatchStrideB1, BatchStrideD1s,
BatchStrideE1, a0_element_op,
b0_element_op, cde0_element_op,
b1_element_op, cde1_element_op};
}
// polymorphic
@@ -902,34 +368,34 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation c_element_op) override
{
return std::make_unique<RawArg>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0),
p_d0s,
static_cast<const B1DataType*>(p_b1),
p_d1s,
static_cast<E1DataType*>(p_c),
M,
N,
K,
O,
Batch,
StrideA,
StrideB0,
StrideD0s,
StrideB1,
StrideD1s,
StrideE1,
BatchStrideA,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideE1,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op);
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0),
p_d0s,
static_cast<const B1DataType*>(p_b1),
p_d1s,
static_cast<E1DataType*>(p_c),
M,
N,
K,
O,
Batch,
StrideA,
StrideB0,
StrideD0s,
StrideB1,
StrideD1s,
StrideE1,
BatchStrideA,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideE1,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op);
}
// Factory returning a value-initialized Invoker for this device op.
static auto MakeInvoker()
{
    return Invoker{};
}

View File

@@ -606,6 +606,26 @@ struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
return GridwiseGemm::CheckValidity(arg);
}

View File

@@ -588,6 +588,28 @@ struct DeviceBatchedGemmReduce_Wmma_CShuffleV3
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},
std::array<const void*, 0>{},

View File

@@ -455,6 +455,26 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_Common
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
return GridwiseGemm::CheckValidity(arg);
}

View File

@@ -471,6 +471,28 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
typename GridwiseGemm::Argument gemm_arg{
std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},

View File

@@ -701,6 +701,28 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemmWelford::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemmWelford::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
typename GridwiseGemmWelford::Argument gemm_arg{
std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},

View File

@@ -456,6 +456,28 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() &&
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},
std::array<const void*, 0>{},

View File

@@ -421,6 +421,26 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
}
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
};

View File

@@ -393,6 +393,26 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
return GridwiseGemm::CheckValidity(
*dynamic_cast<const typename GridwiseGemm::Argument*>(&arg));
}

View File

@@ -450,8 +450,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
BlkGemmPipelineVer,
AComputeType,
BComputeType,
false,
false>;
false, // PermuteA
false, // PermuteB
false, // IsBPreShuffled
true>; // ForceThreadTileTransfer
#define GridwiseGemmCTransposeTemplateParameters \
ALayout, BLayout, DsLayout, ELayout, Tuple<ADataType>, Tuple<BDataType>, AccDataType, \
@@ -467,7 +469,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
ABlockLdsExtraM, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, \
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CShuffleBlockTransferScalarPerVector, BlkGemmPipeSched, BlkGemmPipelineVer, BComputeType, \
AComputeType, false, false
AComputeType, false, false, false, true
using GridwiseGemmCTranspose =
std::conditional_t<CTranspose,

View File

@@ -162,7 +162,6 @@ struct DeviceGroupedConvBwdWeight_Explicit
}
else
{
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
@@ -171,9 +170,11 @@ struct DeviceGroupedConvBwdWeight_Explicit
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
const index_t grid_size = gdx * gdy * gdz;
k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
}
else
#endif
{
k_batch_ = split_k;
}
@@ -338,16 +339,6 @@ struct DeviceGroupedConvBwdWeight_Explicit
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if constexpr(!IsTwoStageNeeded)
{
if(arg.k_batch_ < 0)
{
return false;
}
}
#endif
if constexpr(NDimSpatial == 2)
{
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())

View File

@@ -22,6 +22,7 @@
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -524,6 +525,44 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}, 1, 1));
// Queries, at construction, how many blocks of the backward-weight kernel can be active
// per multiprocessor, and stores the result (clamped to at least 1) in max_occupancy_.
//
// On devices that are neither gfx11 nor gfx12 no query is made and max_occupancy_ keeps
// its default of 1. The same holds for pipeline version v4, where the query is not
// implemented yet.
struct ActiveWorkgroupsPerCU
{
    ActiveWorkgroupsPerCU()
    {
        if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
        {
            // Not a supported target: leave max_occupancy_ at its default.
            return;
        }

        constexpr int dynamic_smem_size = 0;
        constexpr index_t minimum_occupancy =
            BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
        int max_occupancy = 0;

        if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
        {
            // TODO: implement
        }
        else
        {
            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_occupancy,
                kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
                    GridwiseGemm,
                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
                    remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                    ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
                    true,
                    InMemoryDataOperationEnum::AtomicAdd,
                    minimum_occupancy>,
                BlockSize,
                dynamic_smem_size));
        }

        // A reported occupancy of 0 is clamped to 1.
        max_occupancy_ = std::max(1, max_occupancy);
    }

    // Defaults to 1 so that the early-return path above never leaves the member with an
    // indeterminate value.
    int max_occupancy_ = 1;
};
struct Argument : public BaseArgument, public ArgumentSplitK
{
Argument(
@@ -574,6 +613,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
constexpr index_t spatial_offset = 3;
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
end(b_g_n_c_wis_lengths),
@@ -585,7 +626,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
@@ -602,6 +642,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
@@ -611,7 +654,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
}
}
else
#endif
{
k_batch_ = split_k;
}
@@ -988,13 +1030,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *

View File

@@ -677,7 +677,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN;
@@ -688,9 +687,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
}
else
#endif
{
k_batch_ = split_k;
}
@@ -947,12 +948,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
{
return false;

View File

@@ -511,7 +511,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset,
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
@@ -528,6 +528,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
@@ -537,7 +540,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
}
}
else
#endif
{
k_batch_ = split_k;
}
@@ -1040,12 +1042,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *

View File

@@ -651,7 +651,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
e_g_k_c_xs_strides);
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN;
@@ -662,9 +661,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
}
else
#endif
{
k_batch_ = split_k;
}
@@ -1083,12 +1084,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
{
return false;

View File

@@ -568,7 +568,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
@@ -585,6 +584,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
@@ -594,7 +596,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
}
else
#endif
{
k_batch_ = split_k;
}
@@ -1373,12 +1374,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
// check device
if constexpr(DirectLoad)

View File

@@ -503,6 +503,29 @@ struct DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
bool supported = true;
for(index_t i = 0; i < arg.group_count_; ++i)
{
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(
arg.gemm_descs_[i].M_, arg.gemm_descs_[i].K_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(
arg.gemm_descs_[i].N_, arg.gemm_descs_[i].K_))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
std::array<index_t, NumDTensor> stride_Ds;
std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());

View File

@@ -704,7 +704,28 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
bool supported = true;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& a = arg.gemm_kernel_args_[i].karg_;
const auto& a = arg.gemm_kernel_args_[i].karg_;
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(a.M, a.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(a.N, a.K))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
if(not group_arg_valid)

View File

@@ -1631,6 +1631,13 @@ struct ConvInvscale
e = type_convert<f8_t>(c / scale_in_ / scale_wei_ / scale_out_);
};
// f8 -> f8 specialization: the input is widened to float, divided by the input, weight
// and output scales in that order, and the quotient is converted back to f8.
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& e, const f8_t& c) const
{
    const float rescaled = type_convert<float>(c) / scale_in_ / scale_wei_ / scale_out_;
    e                    = type_convert<f8_t>(rescaled);
};
float scale_in_;
float scale_wei_;
float scale_out_;
@@ -1656,6 +1663,13 @@ struct ConvScale
e = type_convert<f8_t>(c * scale_in_ * scale_wei_ * scale_out_);
};
// f8 -> f8 specialization: the input is widened to float, multiplied by the input, weight
// and output scales in that order, and the product is converted back to f8.
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& e, const f8_t& c) const
{
    const float rescaled = type_convert<float>(c) * scale_in_ * scale_wei_ * scale_out_;
    e                    = type_convert<f8_t>(rescaled);
};
float scale_in_;
float scale_wei_;
float scale_out_;
@@ -1683,6 +1697,15 @@ struct ConvScaleRelu
e = type_convert<f8_t>(x * scale_out_);
};
// f8 -> f8 specialization: the input is widened to float and multiplied by the input and
// weight scales, Relu is applied to that product, and the result is multiplied by the
// output scale before being converted back to f8.
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& e, const f8_t& c) const
{
    const float pre_activation = type_convert<float>(c) * scale_in_ * scale_wei_;

    float activated;
    Relu{}.template operator()<float>(activated, pre_activation);

    e = type_convert<f8_t>(activated * scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;

View File

@@ -132,10 +132,6 @@ struct ABTransferWaveTiles
index_t,
index_t)
{
// Notes: padding is currently not supported with transpose
static_assert(!((PadMN || PadK) && ABDoTranspose),
"padding is currently not supported with transpose");
const index_t MN_grid = !PadMN ? sizeMN : MNPad;
const index_t K_grid = !PadK ? sizeK : KPad;

View File

@@ -362,23 +362,27 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
.wave_size;
// Compile-time predicate: true when the template configuration satisfies every condition
// listed below for matrix A; evaluated from template/static constants only.
__host__ __device__ static constexpr bool AWaveTransferApplicable()
{
    // Thread-tile transfer is not forced and B is not pre-shuffled.
    const bool path_allowed = !ForceThreadTileTransfer && !IsBPreShuffled;
    // Exactly one A tensor with a packed size of 1.
    const bool single_unpacked_tensor = NumATensor == 1 && APackedSize == 1;
    // Source and destination vector widths, and AK1, are all 8.
    const bool vector_width_is_8 = ABlockTransferSrcScalarPerVector == 8 &&
                                   ABlockTransferDstScalarPerVector_AK1 == 8 && AK1Value == 8;
    // Only pipeline version v1 is accepted.
    const bool pipeline_is_v1 = BlkGemmPipelineVer == BlockGemmPipelineVersion::v1;

    return path_allowed && single_unpacked_tensor && vector_width_is_8 && pipeline_is_v1;
}
// Compile-time predicate: true when the template configuration satisfies every condition
// listed below for matrix B; evaluated from template/static constants only.
__host__ __device__ static constexpr bool BWaveTransferApplicable()
{
    // Thread-tile transfer is not forced.
    const bool path_allowed = !ForceThreadTileTransfer;
    // Exactly one B tensor with a packed size of 1.
    const bool single_unpacked_tensor = NumBTensor == 1 && BPackedSize == 1;
    // Source and destination vector widths, and BK1, are all 8.
    const bool vector_width_is_8 = BBlockTransferSrcScalarPerVector == 8 &&
                                   BBlockTransferDstScalarPerVector_BK1 == 8 && BK1Value == 8;
    // Only pipeline version v1 is accepted.
    const bool pipeline_is_v1 = BlkGemmPipelineVer == BlockGemmPipelineVersion::v1;

    return path_allowed && single_unpacked_tensor && vector_width_is_8 && pipeline_is_v1;
}
// Limitations of the current implementation:
// - no multiAB
// - GemmSpecialization Default with transpose
#ifdef __gfx12__
static constexpr bool IsAWaveTransferApplicable =
!ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 &&
((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
!is_same_v<ALayout, tensor_layout::gemm::RowMajor>) ||
is_same_v<ALayout, tensor_layout::gemm::RowMajor>) &&
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled;
static constexpr bool IsAWaveTransferApplicable = AWaveTransferApplicable();
static constexpr bool IsBWaveTransferApplicable =
!ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 &&
((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
!is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) ||
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) &&
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8;
static constexpr bool IsBWaveTransferApplicable = BWaveTransferApplicable();
static constexpr bool IsWaveTileInterleavedFitting =
(NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize);
@@ -982,6 +986,55 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return de_grid_desc_mblock_mperblock_nblock_nperblock;
}
// Conditions for Wave Transfer with transpose:
// - 16 bit type: K % 8 == 0 (4 subtiles of 8x8)
// - 8 bit type: K % 8 == 0 and M % 16 == 0 (2 subtiles of 8x16)
// Validates the problem sizes for the A operand when the wave transfer is used with a
// transpose (i.e. A is not RowMajor). Returns true when no constraint applies.
//   M, K : GEMM problem sizes along M and K.
__host__ static constexpr bool CheckValidityAWaveTransfer(const index_t& M, const index_t& K)
{
    // Size constraints exist only on the transposing wave-transfer path.
    constexpr bool transposing_wave_transfer =
        AWaveTransferApplicable() && !is_same<tensor_layout::gemm::RowMajor, ALayout>::value;

    if constexpr(!transposing_wave_transfer)
    {
        return true;
    }
    else
    {
        // K must be a whole number of destination vectors.
        if(K % ABlockTransferDstScalarPerVector_AK1 != 0)
        {
            return false;
        }

        // 1-byte element types additionally need M to be a multiple of twice the source
        // vector width.
        bool all_ok = true;
        static_for<0, NumATensor, 1>{}([&](auto idx) {
            using ElemType = remove_cvref_t<tuple_element_t<idx.value, AsDataType>>;
            if(sizeof(ElemType) == 1 && M % (2 * ABlockTransferSrcScalarPerVector) != 0)
            {
                all_ok = false;
            }
        });
        return all_ok;
    }
}
// Validates the problem sizes for the B operand when the wave transfer is used with a
// transpose (i.e. B is not ColumnMajor). Returns true when no constraint applies.
//   N, K : GEMM problem sizes along N and K.
__host__ static constexpr bool CheckValidityBWaveTransfer(const index_t& N, const index_t& K)
{
    // Size constraints exist only on the transposing wave-transfer path.
    constexpr bool transposing_wave_transfer =
        BWaveTransferApplicable() &&
        !is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value;

    if constexpr(!transposing_wave_transfer)
    {
        return true;
    }
    else
    {
        // K must be a whole number of destination vectors.
        if(K % BBlockTransferDstScalarPerVector_BK1 != 0)
        {
            return false;
        }

        // 1-byte element types additionally need N to be a multiple of twice the source
        // vector width.
        bool all_ok = true;
        static_for<0, NumBTensor, 1>{}([&](auto idx) {
            using ElemType = remove_cvref_t<tuple_element_t<idx.value, BsDataType>>;
            if(sizeof(ElemType) == 1 && N % (2 * BBlockTransferSrcScalarPerVector) != 0)
            {
                all_ok = false;
            }
        });
        return all_ok;
    }
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Argument>
__host__ static constexpr bool CheckValidity(const Argument& karg,

View File

@@ -199,55 +199,113 @@ template <index_t N>
using make_index_sequence =
typename __make_integer_seq<impl::__integer_sequence, index_t, N>::seq_type;
// merge sequence
template <typename Seq, typename... Seqs>
struct sequence_merge
// merge sequence - optimized to avoid recursive instantiation
//
// Note: Unlike sequence_gen and uniform_sequence_gen which use __make_integer_seq for O(1)
// instantiation depth, sequence_merge cannot achieve O(1) depth. Here's why:
//
// - sequence_gen and uniform_sequence_gen generate a SINGLE output sequence where each
// element can be computed independently: output[i] = f(i)
//
// - sequence_merge takes MULTIPLE input sequences with different, unknown lengths.
// To compute output[i], we need to know:
// 1. Which input sequence contains this index
// 2. The offset within that sequence
// This requires computing cumulative sequence lengths, which requires recursion/iteration.
//
// Instead, we consume several sequences per instantiation level to keep the depth low:
// - Base cases handle 1-4 sequences directly (O(1) for common cases)
// - Recursive case merges the first pair, then recurses on the rest:
//   merge(s1,s2) + merge(s3,s4,...)
// - Note: the remainder is not split in half, so the depth for N sequences is roughly N/2
//   (fewer levels than one-at-a-time merging, but not logarithmic). A balanced split of
//   the argument pack would be required for a true O(log N) depth.
//
// Alternative considered: Fold expressions (... + sequences) would give O(N) depth due to
// a linear dependency chain with one level per sequence, so consuming multiple sequences
// per level is preferable.
//
namespace detail {
// Helper to concatenate multiple sequences; the specializations below concatenate up to
// four sequences in one step via pack expansion
template <typename... Seqs>
struct sequence_merge_impl;
// Base case: single sequence
template <index_t... Is>
struct sequence_merge_impl<Sequence<Is...>>
{
using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
using type = Sequence<Is...>;
};
// Two sequences: direct concatenation
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>>
{
using type = Sequence<Xs..., Ys...>;
};
template <typename Seq>
struct sequence_merge<Seq>
// Three sequences: direct concatenation (avoids one level of recursion)
template <index_t... Xs, index_t... Ys, index_t... Zs>
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>>
{
using type = Seq;
using type = Sequence<Xs..., Ys..., Zs...>;
};
// generate sequence
// Four sequences: direct concatenation
template <index_t... As, index_t... Bs, index_t... Cs, index_t... Ds>
struct sequence_merge_impl<Sequence<As...>, Sequence<Bs...>, Sequence<Cs...>, Sequence<Ds...>>
{
using type = Sequence<As..., Bs..., Cs..., Ds...>;
};
// General case (selected for five or more Sequence arguments, since the 1-4 sequence
// specializations are more specialized): merge the first pair directly, recursively merge
// everything from S3 onward, then concatenate the two partial results.
// NOTE(review): each level peels off only two sequences before recursing on the
// remainder, so the nesting depth grows roughly as N/2, not log N.
template <typename S1, typename S2, typename S3, typename S4, typename... Rest>
struct sequence_merge_impl<S1, S2, S3, S4, Rest...>
{
// Merge the leading pair first, then recurse on the remaining sequences
using left = typename sequence_merge_impl<S1, S2>::type;
using right = typename sequence_merge_impl<S3, S4, Rest...>::type;
using type = typename sequence_merge_impl<left, right>::type;
};
} // namespace detail
// Public entry point: concatenates any non-zero number of Sequence types, in argument
// order, by forwarding to detail::sequence_merge_impl. The result is exposed as ::type.
template <typename... Seqs>
struct sequence_merge
{
using type = typename detail::sequence_merge_impl<Seqs...>::type;
};
// Merging zero sequences yields the empty Sequence (handled here because
// detail::sequence_merge_impl has no specialization for an empty argument list).
template <>
struct sequence_merge<>
{
using type = Sequence<>;
};
// generate sequence - optimized using __make_integer_seq to avoid recursive instantiation
namespace detail {
// Helper that applies functor F to indices and produces a Sequence
// __make_integer_seq<sequence_gen_helper, index_t, N> produces sequence_gen_helper<index_t, 0, 1,
// ..., N-1>
// Receives the index pack Is... produced by __make_integer_seq and exposes apply<F>,
// which maps every index through the functor F in a single pack expansion.
template <typename T, T... Is>
struct sequence_gen_helper
{
// Apply a functor F to all indices at once via pack expansion (O(1) depth).
// F is default-constructed and called with Number<Is>{}; because the results are used
// as template arguments, that call must be usable in a constant expression.
template <typename F>
using apply = Sequence<F{}(Number<Is>{})...>;
};
} // namespace detail
template <index_t NSize, typename F>
struct sequence_gen
{
template <index_t IBegin, index_t NRemain, typename G>
struct sequence_gen_impl
{
static constexpr index_t NRemainLeft = NRemain / 2;
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
static constexpr index_t IMiddle = IBegin + NRemainLeft;
using type =
typename __make_integer_seq<detail::sequence_gen_helper, index_t, NSize>::template apply<F>;
};
using type = typename sequence_merge<
typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
};
template <index_t I, typename G>
struct sequence_gen_impl<I, 1, G>
{
static constexpr index_t Is = G{}(Number<I>{});
using type = Sequence<Is>;
};
template <index_t I, typename G>
struct sequence_gen_impl<I, 0, G>
{
using type = Sequence<>;
};
using type = typename sequence_gen_impl<0, NSize, F>::type;
template <typename F>
struct sequence_gen<0, F>
{
using type = Sequence<>;
};
// arithmetic sequence
@@ -283,16 +341,30 @@ struct arithmetic_sequence_gen<0, IEnd, 1>
using type = typename __make_integer_seq<WrapSequence, index_t, IEnd>::type;
};
// uniform sequence
// uniform sequence - optimized using __make_integer_seq
namespace detail {
// Receives the index pack Is... produced by __make_integer_seq and exposes apply<Value>,
// a Sequence holding one copy of Value per index.
template <typename T, T... Is>
struct uniform_sequence_helper
{
// Apply a constant value to all indices via pack expansion.
// The comma expression ((void)Is, Value) discards each index and evaluates to Value;
// Is is mentioned only so that the expansion repeats sizeof...(Is) times.
template <index_t Value>
using apply = Sequence<((void)Is, Value)...>;
};
} // namespace detail
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
struct F
{
__host__ __device__ constexpr index_t operator()(index_t) const { return I; }
};
using type = typename __make_integer_seq<detail::uniform_sequence_helper, index_t, NSize>::
template apply<I>;
};
using type = typename sequence_gen<NSize, F>::type;
template <index_t I>
struct uniform_sequence_gen<0, I>
{
using type = Sequence<>;
};
// reverse inclusive scan (with init) sequence

View File

@@ -20,6 +20,7 @@ struct tuple_concat<Tuple<Xs...>, Tuple<Ys...>>
using type = Tuple<Xs..., Ys...>;
};
// StaticallyIndexedArrayImpl uses binary split for O(log N) depth
template <typename T, index_t N>
struct StaticallyIndexedArrayImpl
{

View File

@@ -19,7 +19,7 @@
namespace ck_tile {
/** @brief Maximum number of error values to display when checking errors */
constexpr int ERROR_DETAIL_LIMIT = 128;
constexpr int ERROR_DETAIL_LIMIT = 16;
/** @brief 8-bit floating point type */
using F8 = ck_tile::fp8_t;

View File

@@ -227,7 +227,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
sequence<1>>{});
else
return make_static_tile_distribution(
tile_distribution_encoding< //
tile_distribution_encoding<
sequence<NWarps>,
tuple<sequence<MWarps, MXdlPack, MPerXdl>,
sequence<K_Thread / AK1, K_Lane, AK1 / APackedSize>>,

View File

@@ -274,7 +274,9 @@ struct AQuantBlockUniversalGemmAsBsCr
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
CWarpTensor c_warp_tensor;
// for every column in AQ
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
// for every warp corresponding to a quantization scale
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
@@ -322,6 +324,214 @@ struct AQuantBlockUniversalGemmAsBsCr
}
};
// Specialization of BlockGemmImpl for the Interwave scheduler, with A-quantization scaling.
// The per-thread K extent is processed in KRepeat chunks of KPerInnerLoop elements: each
// chunk is fetched into the member warp tiles by LocalPrefetch<KIdx>() and consumed by
// operator(), which also applies the per-group quantization scale to the accumulated result.
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Interwave, GemmTraits>
{
static constexpr index_t KPerThread = GemmTraits::KPerThread;
static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
// K elements covered by one prefetched chunk: KPerThread divided among the MAC clusters,
// but never less than what a single warp GEMM consumes along K.
static constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
// Number of chunks along K, and number of warp-GEMM steps inside one chunk.
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
static constexpr auto ALdsTileDistr =
make_static_tile_distribution(MakeABlockDistributionEncode());
static constexpr auto BLdsTileDistr =
make_static_tile_distribution(MakeBBlockDistributionEncode());
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
// Warp tiles holding the currently prefetched K chunk of A and B.
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
// Loads K chunk number KIdx of A and B into a_warp_tile_ / b_warp_tile_.
//   a_block_window, b_block_window : block windows whose bottom tensor views are read.
//   ALoadTranspose, BLoadTranspose : select the transposed distribution, the (K, MN) tile
//                                    shape and the K-first offset for the respective operand.
template <index_t KIdx,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
// Tile distributions used for the load (transposed encoding when requested).
constexpr auto a_lds_load_distr = [&]() {
if constexpr(ALoadTranspose)
return make_static_tile_distribution(typename InputTileDistributionTraits<
decltype(MakeABlockDistributionEncode()),
ADataType>::TransposedDstrEncode{});
else
return make_static_tile_distribution(MakeABlockDistributionEncode());
}();
constexpr auto b_lds_load_distr = [&]() {
if constexpr(BLoadTranspose)
return make_static_tile_distribution(typename InputTileDistributionTraits<
decltype(MakeBBlockDistributionEncode()),
BDataType>::TransposedDstrEncode{});
else
return make_static_tile_distribution(MakeBBlockDistributionEncode());
}();
// Window shapes: (MPerBlock|NPerBlock, KPerInnerLoop), or K-first when transposed.
constexpr auto a_lds_shape = []() {
if constexpr(ALoadTranspose)
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::MPerBlock>{});
else
return make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{});
}();
constexpr auto b_lds_shape = []() {
if constexpr(BLoadTranspose)
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::NPerBlock>{});
else
return make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{});
}();
// Start of chunk KIdx along K; placed in the dimension that holds K for each operand.
constexpr auto k_idx_offset = KIdx * KPerInnerLoop;
constexpr auto a_offset =
ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
constexpr auto b_offset =
BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
auto a_lds_gemm_window = make_tile_window(
a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr);
auto b_lds_gemm_window = make_tile_window(
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
// NOTE(review): the A tile load below is instantiated with BDataType (not ADataType) as its
// first template argument - confirm this is intentional for the A-quant path.
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
a_warp_tile_, a_lds_gemm_window);
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
b_warp_tile_, b_lds_gemm_window);
}
// C += A * B with quantization support
//   c_block_tensor  : accumulator tile; scaled partial products are added into it.
//   aq_block_tensor : A quantization scales, read through AQPickerCommon.
//   a_block_window, b_block_window, a_load_tr, b_load_tr : forwarded to LocalPrefetch.
template <typename CBlockTensor,
typename AQBlockTensor,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
AQBlockTensor& aq_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as corresponding "
"C block tensor data type!");
constexpr auto warp_size = get_warp_size();
// Track which KRepeat chunk is currently loaded
// (a run-time value, so the comparison further down is an ordinary run-time branch;
// -1 means that nothing has been loaded yet)
index_t current_k_repeat_loaded = -1;
// Restructured loop: M → N → QScale → KIterPerQScale
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// Iterate over quantization groups
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
// Unscaled accumulator for the current quantization group only.
CWarpTensor c_warp_tensor;
// Accumulate K iterations for this quantization group
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
// Map quantization indices to global K iteration
constexpr auto kIterGlobal =
kQScale * Traits::KIterPerQScale + kIterInQScale;
// Map to KRepeat chunk and KInnerLoopIter offset
constexpr auto kRepeatIdx = kIterGlobal / KInnerLoopIter;
constexpr auto kInnerIdx = kIterGlobal % KInnerLoopIter;
// Prefetch new chunk if needed
// (only considered at the first step of a chunk, and skipped when that chunk is
// already the one held in the warp tiles)
if constexpr(kInnerIdx == 0)
{
if(current_k_repeat_loaded != kRepeatIdx)
{
LocalPrefetch<kRepeatIdx>(
a_block_window, b_block_window, a_load_tr, b_load_tr);
__builtin_amdgcn_sched_barrier(0);
// The s_barrier is issued after every prefetch except the one for chunk 0
// when there is more than one chunk.
if constexpr(kRepeatIdx != 0 || KRepeat == 1)
{
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
}
current_k_repeat_loaded = kRepeatIdx;
}
}
// Load A warp tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() =
a_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kInnerIdx>{},
a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// Load B warp tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() =
b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kInnerIdx>{},
b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// Synchronization barrier at the end of last iteration
// (last quantization group, last K step, last M and last N iteration)
if constexpr(kQScale == Traits::QScalesPerBlockRow - 1 &&
kIterInQScale == Traits::KIterPerQScale - 1 &&
mIter.value == MIterPerWarp - 1 &&
nIter.value == NIterPerWarp - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
// Accumulate: first iteration initializes, rest accumulate
if constexpr(kIterInQScale == 0)
{
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
}
else
{
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
}
// Set priority for scheduling
// (raised to 1 after the first warp GEMM of a chunk, in the first M/N iteration only)
if constexpr(kInnerIdx == 0 && mIter.value == 0 && nIter.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
// Apply quantization scale after accumulating all K iterations for this
// group
constexpr auto tbuf_offset =
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
merge_sequences(sequence<mIter, nIter>{},
c_warp_y_index_zeros)) /
CBlockTensor::PackedSize>{};
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
aq_block_tensor);
// Each of the thread's accumulator elements is multiplied by its picked scale and
// added to the matching element of the C block tensor.
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
float scale_reg_f = aq_picker.template pick<c_row>();
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
});
});
// Priority is reset to 0 once per M iteration, after all of its N iterations.
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
});
}
};
public:
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
@@ -329,7 +539,8 @@ struct AQuantBlockUniversalGemmAsBsCr
MakeCBlockTile();
}
template <typename ASmemBlockWindow,
template <index_t KIdx = 0,
typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
@@ -338,7 +549,15 @@ struct AQuantBlockUniversalGemmAsBsCr
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
if constexpr(Scheduler == GemmPipelineScheduler::Interwave)
{
block_gemm_impl_.template LocalPrefetch<KIdx>(
a_block_window, b_block_window, a_load_tr, b_load_tr);
}
else
{
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
}
}
// C += A * B

View File

@@ -499,7 +499,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
return PipelineImpl<GemmPipelineScheduler::Intrawave>{}
.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const OverrideADataType& a) { return a; },
[](const BDataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
aq_dram_block_window_tmp,

View File

@@ -392,8 +392,4 @@ struct BlockReduce2D
InDataType reduce_init;
};
// deduction guide
template <typename T>
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&) -> BlockReduce2D<T>;
} // namespace ck_tile

View File

@@ -40,7 +40,7 @@ struct BlockSoftmax2D
#endif
// compute row max
auto reduce_row_max = BlockReduce2D{x, -numeric<DataType>::infinity()};
auto reduce_row_max = BlockReduce2D<decltype(x)>{x, -numeric<DataType>::infinity()};
#if _BLOCK_SOFTMAX_USE_UNPACK2
auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{});
#else

View File

@@ -10,49 +10,55 @@
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <array>
namespace ck {
namespace ref {
// Optimized backward data convolution kernel working with packed (contiguous) tensors
// Computes gradients w.r.t. input from output gradients and weights
// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter],
// output[G][N][K][spatial]
// Optimized backward data convolution kernel working with packed (contiguous) tensors, with
// multi-ABD support. Computes gradients w.r.t. input from output gradients and weights.
// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter],
// output[G][N][K][spatial].
template <index_t NDimSpatial,
index_t NumAExtra, // Number of extra A (output gradient) tensors
index_t NumBExtra, // Number of extra B (weight) tensors
index_t NumD, // Number of D tensors
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename DDataType, // D tensor data type
typename InElementOp,
typename WeiElementOp,
typename OutElementOp>
__global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
const WeiDataType* __restrict__ p_wei,
const OutDataType* __restrict__ p_out,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
__global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_in,
const WeiDataType* const* __restrict__ p_weis,
const OutDataType* const* __restrict__ p_outs,
const DDataType* const* __restrict__ p_ds,
const index_t* const* __restrict__ p_d_strides,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
{
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const long_index_t num_threads = blockDim.x * gridDim.x;
@@ -84,9 +90,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
float acc = 0.0f;
// Base pointers for current group and batch
const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n;
const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g;
for(index_t x = 0; x < X; ++x)
{
@@ -96,21 +103,39 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
long_index_t wo = w_tmp / stride_x;
if(wo >= 0 && wo < Wo)
{
const OutDataType* out_gnk = out_gn;
const WeiDataType* wei_gkc = wei_g + c * wei_stride_c;
// Pointers at current filter position
const OutDataType* output_grad_g_n_k = output_grad_g_n;
const WeiDataType* weight_g_k_c = weight_g + c * wei_stride_c;
for(index_t k = 0; k < K; ++k)
{
out_op(out_val, out_gnk[k * out_stride_k + wo]);
wei_op(wei_val, wei_gkc[k * wei_stride_k + x]);
// Handle output gradient element-wise operation with extra A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
out_val,
out_op,
output_grad_g_n_k,
p_outs + 1,
g * out_stride_g + n * out_stride_n,
k * out_stride_k + wo);
// Handle weight element-wise operation with extra B tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
wei_val,
wei_op,
weight_g_k_c,
p_weis + 1,
g * wei_stride_g + c * wei_stride_c,
k * wei_stride_k + x);
acc += type_convert<float>(out_val) * type_convert<float>(wei_val);
}
}
}
}
InDataType result = type_convert<InDataType>(acc);
in_op(in_val, result);
detail::apply_d_tensor_elementwise_op<NumD>(
in_val, in_op, acc, p_ds, p_d_strides, g, n, c, wi);
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + wi] = in_val;
}
}
@@ -142,9 +167,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
float acc = 0.0f;
// Base pointers for current group and batch
const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n;
const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g;
for(index_t y = 0; y < Y; ++y)
{
@@ -154,8 +180,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
long_index_t ho = h_tmp / stride_y;
if(ho >= 0 && ho < Ho)
{
const OutDataType* out_gnkh = out_gn + ho * out_stride_h;
const WeiDataType* wei_gkcy = wei_g + c * wei_stride_c + y * wei_stride_y;
// Pointers at current spatial height and filter Y position
const OutDataType* output_grad_at_h = output_grad_g_n + ho * out_stride_h;
const WeiDataType* weight_at_c_y =
weight_g + c * wei_stride_c + y * wei_stride_y;
for(index_t x = 0; x < X; ++x)
{
@@ -167,8 +195,25 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
{
for(index_t k = 0; k < K; ++k)
{
out_op(out_val, out_gnkh[k * out_stride_k + wo]);
wei_op(wei_val, wei_gkcy[k * wei_stride_k + x]);
// Handle output gradient element-wise operation with extra
// A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
out_val,
out_op,
output_grad_at_h,
p_outs + 1,
g * out_stride_g + n * out_stride_n + ho * out_stride_h,
k * out_stride_k + wo);
// Handle weight element-wise operation with extra B tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
wei_val,
wei_op,
weight_at_c_y,
p_weis + 1,
g * wei_stride_g + c * wei_stride_c + y * wei_stride_y,
k * wei_stride_k + x);
acc += type_convert<float>(out_val) *
type_convert<float>(wei_val);
}
@@ -179,8 +224,17 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
}
}
InDataType result = type_convert<InDataType>(acc);
in_op(in_val, result);
detail::apply_d_tensor_elementwise_op<NumD>(in_val,
in_op,
acc,
p_ds,
p_d_strides,
g,
n,
c,
hi * p_d_strides[0][3] +
wi * p_d_strides[0][4]);
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + hi * in_stride_h + wi] =
in_val;
}
@@ -218,9 +272,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
float acc = 0.0f;
// Base pointers for current group and batch
const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n;
const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g;
for(index_t z = 0; z < Z; ++z)
{
@@ -230,8 +285,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
long_index_t do_idx = d_tmp / stride_z;
if(do_idx >= 0 && do_idx < Do)
{
const OutDataType* out_gnkd = out_gn + do_idx * out_stride_d;
const WeiDataType* wei_gkcz = wei_g + c * wei_stride_c + z * wei_stride_z;
// Pointers at current spatial depth
const OutDataType* output_grad_at_d =
output_grad_g_n + do_idx * out_stride_d;
const WeiDataType* weight_at_c_z =
weight_g + c * wei_stride_c + z * wei_stride_z;
for(index_t y = 0; y < Y; ++y)
{
@@ -241,8 +299,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
long_index_t ho = h_tmp / stride_y;
if(ho >= 0 && ho < Ho)
{
const OutDataType* out_gnkdh = out_gnkd + ho * out_stride_h;
const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y;
// Pointers at current spatial depth and height
const OutDataType* output_grad_at_d_h =
output_grad_at_d + ho * out_stride_h;
const WeiDataType* weight_at_c_z_y =
weight_at_c_z + y * wei_stride_y;
for(index_t x = 0; x < X; ++x)
{
@@ -254,10 +315,31 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
{
for(index_t k = 0; k < K; ++k)
{
out_op(out_val,
out_gnkdh[k * out_stride_k + wo]);
wei_op(wei_val,
wei_gkczy[k * wei_stride_k + x]);
// Handle output gradient element-wise operation
// with extra A tensors
detail::apply_multi_tensor_elementwise_op<
NumAExtra>(out_val,
out_op,
output_grad_at_d_h,
p_outs + 1,
g * out_stride_g +
n * out_stride_n +
do_idx * out_stride_d +
ho * out_stride_h,
k * out_stride_k + wo);
// Handle weight element-wise operation with
// extra B tensors
detail::apply_multi_tensor_elementwise_op<
NumBExtra>(
wei_val,
wei_op,
weight_at_c_z_y,
p_weis + 1,
g * wei_stride_g + c * wei_stride_c +
z * wei_stride_z + y * wei_stride_y,
k * wei_stride_k + x);
acc += type_convert<float>(out_val) *
type_convert<float>(wei_val);
}
@@ -271,16 +353,28 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
}
}
InDataType result = type_convert<InDataType>(acc);
in_op(in_val, result);
detail::apply_d_tensor_elementwise_op<NumD>(
in_val,
in_op,
acc,
p_ds,
p_d_strides,
g,
n,
c,
di * p_d_strides[0][3] + hi * p_d_strides[0][4] + wi * p_d_strides[0][5]);
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + di * in_stride_d +
hi * in_stride_h + wi] = in_val;
}
}
}
// GPU reference backward data convolution - takes ConvParam directly
template <typename InLayout,
// GPU reference backward data convolution with multi-ABD support - takes ConvParam directly
template <ck::index_t NumAElementwise = 0,
ck::index_t NumBElementwise = 0,
ck::index_t NumDElementwise = 0,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename TIn,
@@ -288,15 +382,20 @@ template <typename InLayout,
typename TOut,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
void naive_conv_bwd_data(TIn* p_in,
const TWei* p_wei,
const TOut* p_out,
const ck::utils::conv::ConvParam& conv_param,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
typename OutElementwiseOperation,
typename TD = TIn> // D tensor type, defaults to TIn for backward compatibility
void naive_conv_bwd_data_multi_abd(
TIn* p_in,
const std::array<const TWei*, NumBElementwise + 1>& p_weis,
const std::array<const TOut*, NumAElementwise + 1>& p_outs,
const std::array<const TD*, NumDElementwise>& p_ds,
const ck::utils::conv::ConvParam& conv_param,
[[maybe_unused]] const std::array<std::vector<index_t>, NumDElementwise>& d_lengths,
const std::array<std::vector<index_t>, NumDElementwise>& d_strides,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
{
const auto ndim = conv_param.num_dim_spatial_;
@@ -327,12 +426,34 @@ void naive_conv_bwd_data(TIn* p_in,
// Allocate packed buffers
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei));
SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut));
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
TWei* p_wei_packed = static_cast<TWei*>(wei_packed_buf.GetDeviceBuffer());
TOut* p_out_packed = static_cast<TOut*>(out_packed_buf.GetDeviceBuffer());
std::vector<SimpleDeviceMem> wei_packed_bufs;
wei_packed_bufs.reserve(NumBElementwise + 1);
for(index_t i = 0; i <= NumBElementwise; ++i)
{
wei_packed_bufs.emplace_back(wei_total * sizeof(TWei));
}
std::vector<SimpleDeviceMem> out_packed_bufs;
out_packed_bufs.reserve(NumAElementwise + 1);
for(index_t i = 0; i <= NumAElementwise; ++i)
{
out_packed_bufs.emplace_back(out_total * sizeof(TOut));
}
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
std::array<TWei*, NumBElementwise + 1> p_weis_packed;
for(index_t i = 0; i <= NumBElementwise; ++i)
{
p_weis_packed[i] = static_cast<TWei*>(wei_packed_bufs[i].GetDeviceBuffer());
}
std::array<TOut*, NumAElementwise + 1> p_outs_packed;
for(index_t i = 0; i <= NumAElementwise; ++i)
{
p_outs_packed[i] = static_cast<TOut*>(out_packed_bufs[i].GetDeviceBuffer());
}
// Compute strides and allocate device arrays for pack/unpack
std::vector<index_t> in_strides = compute_conv_tensor_strides<InLayout>(in_lengths, ndim);
@@ -369,12 +490,76 @@ void naive_conv_bwd_data(TIn* p_in,
// Pack output and weight tensors to contiguous layout (inputs to bwd data)
constexpr int block_size = 256;
strided_copy_kernel<TOut, false>
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_out, p_out_packed, d_out_lengths, d_out_strides, dim_count, out_total);
strided_copy_kernel<TWei, false>
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total);
for(index_t i = 0; i <= NumAElementwise; ++i)
{
strided_copy_kernel<TOut, false>
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_outs[i], p_outs_packed[i], d_out_lengths, d_out_strides, dim_count, out_total);
}
for(index_t i = 0; i <= NumBElementwise; ++i)
{
strided_copy_kernel<TWei, false>
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total);
}
// Prepare D tensor stride arrays on device
std::vector<SimpleDeviceMem> d_stride_bufs;
std::array<index_t*, NumDElementwise> p_d_strides_dev = {};
if constexpr(NumDElementwise > 0)
{
d_stride_bufs.reserve(NumDElementwise);
for(index_t i = 0; i < NumDElementwise; ++i)
{
d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t));
p_d_strides_dev[i] = static_cast<index_t*>(d_stride_bufs[i].GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i],
d_strides[i].data(),
d_strides[i].size() * sizeof(index_t),
hipMemcpyHostToDevice));
}
}
// Create device arrays of pointers
SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*));
SimpleDeviceMem outs_ptrs_buf((NumAElementwise + 1) * sizeof(TOut*));
SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*));
SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*));
TWei** d_weis_ptrs = static_cast<TWei**>(weis_ptrs_buf.GetDeviceBuffer());
TOut** d_outs_ptrs = static_cast<TOut**>(outs_ptrs_buf.GetDeviceBuffer());
TD** d_ds_ptrs = static_cast<TD**>(ds_ptrs_buf.GetDeviceBuffer());
index_t** d_d_strides_ptrs = static_cast<index_t**>(d_strides_ptrs_buf.GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs,
p_weis_packed.data(),
(NumBElementwise + 1) * sizeof(TWei*),
hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_outs_ptrs,
p_outs_packed.data(),
(NumAElementwise + 1) * sizeof(TOut*),
hipMemcpyHostToDevice));
if constexpr(NumDElementwise > 0)
{
std::array<const TD*, NumDElementwise> p_ds_dev;
for(index_t i = 0; i < NumDElementwise; ++i)
{
p_ds_dev[i] = p_ds[i];
}
HIP_CHECK_ERROR(hipMemcpy(
d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs,
p_d_strides_dev.data(),
NumDElementwise * sizeof(index_t*),
hipMemcpyHostToDevice));
}
// Build conv parameter vectors for kernel invocation
std::vector<index_t> conv_strides(ndim);
@@ -392,16 +577,22 @@ void naive_conv_bwd_data(TIn* p_in,
if(ndim == 1)
{
naive_conv_bwd_data_packed<1,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
naive_conv_bwd_data_packed_multi_abd<1,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
p_out_packed,
d_weis_ptrs,
d_outs_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
G,
N,
K,
@@ -430,16 +621,22 @@ void naive_conv_bwd_data(TIn* p_in,
}
else if(ndim == 2)
{
naive_conv_bwd_data_packed<2,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
naive_conv_bwd_data_packed_multi_abd<2,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
p_out_packed,
d_weis_ptrs,
d_outs_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
G,
N,
K,
@@ -468,16 +665,22 @@ void naive_conv_bwd_data(TIn* p_in,
}
else // 3D
{
naive_conv_bwd_data_packed<3,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
naive_conv_bwd_data_packed_multi_abd<3,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
p_out_packed,
d_weis_ptrs,
d_outs_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
G,
N,
K,
@@ -514,5 +717,43 @@ void naive_conv_bwd_data(TIn* p_in,
// Memory automatically freed by SimpleDeviceMem destructors
}
// Single-tensor entry point for the GPU reference backward-data convolution.
//
// Kept so existing callers that pass one weight tensor and one output-gradient tensor keep
// compiling. It forwards to naive_conv_bwd_data_multi_abd with zero extra A tensors, zero extra
// B tensors and zero D tensors. The forwarding itself adds no work, but the multi-ABD
// implementation stages its tensor-pointer tables on the device before launching the kernel,
// so this is a thin wrapper rather than a strictly cost-free one.
//
// p_in            destination (input gradient) tensor, written by the callee
// p_wei / p_out   weight tensor and output-gradient tensor, read-only
// conv_param      convolution problem description, forwarded unchanged
// *_element_op    element-wise functors, forwarded unchanged
// stream          HIP stream, forwarded unchanged (defaults to nullptr)
template <typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename TIn,
          typename TWei,
          typename TOut,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
inline void naive_conv_bwd_data(TIn* p_in,
                                const TWei* p_wei,
                                const TOut* p_out,
                                const ck::utils::conv::ConvParam& conv_param,
                                InElementwiseOperation in_element_op   = InElementwiseOperation{},
                                WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
                                OutElementwiseOperation out_element_op = OutElementwiseOperation{},
                                hipStream_t stream                     = nullptr)
{
    // The multi-ABD interface takes arrays of tensor pointers; wrap each lone pointer in a
    // one-element array (slot 0 is the primary tensor).
    std::array<const TWei*, 1> weight_ptrs{p_wei};
    std::array<const TOut*, 1> output_ptrs{p_out};

    // No D tensors in the single-tensor case: every D-related argument is an empty array.
    // NOTE(review): the D pointer array is typed on TIn; presumably that matches the callee's
    // default D data type for backward data - confirm against its declaration.
    std::array<const TIn*, 0> no_d_ptrs{};
    std::array<std::vector<index_t>, 0> no_d_lengths{};
    std::array<std::vector<index_t>, 0> no_d_strides{};

    naive_conv_bwd_data_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(
        p_in,
        weight_ptrs,
        output_ptrs,
        no_d_ptrs,
        conv_param,
        no_d_lengths,
        no_d_strides,
        in_element_op,
        wei_element_op,
        out_element_op,
        stream);
}
} // namespace ref
} // namespace ck

View File

@@ -10,49 +10,58 @@
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <array>
namespace ck {
namespace ref {
// Optimized backward weight convolution kernel working with packed (contiguous) tensors
// Optimized backward weight convolution kernel working with packed (contiguous) tensors with
// multi-ABD support
// Assumes row-major packing: input[G][N][C][spatial], output_grad[G][N][K][spatial],
// weight_grad[G][K][C][filter]
// Computes gradient with respect to weights
template <index_t NDimSpatial,
index_t NumAExtra, // Number of extra A (input) tensors
index_t NumBExtra, // Number of extra B (output gradient) tensors
index_t NumD, // Number of D tensors
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename DDataType, // D tensor data type
typename InElementOp,
typename WeiElementOp,
typename OutElementOp>
__global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in,
WeiDataType* __restrict__ p_wei_grad,
const OutDataType* __restrict__ p_out_grad,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
__global__ void
naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_ins,
WeiDataType* __restrict__ p_wei_grad,
const OutDataType* const* __restrict__ p_out_grads,
const DDataType* const* __restrict__ p_ds,
const index_t* const* __restrict__ p_d_strides,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
{
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const long_index_t num_threads = blockDim.x * gridDim.x;
@@ -84,30 +93,50 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
const index_t k = remaining % K;
const index_t g = remaining / K;
float acc = 0.0f;
const InDataType* in_g = p_in + g * in_stride_g;
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
float acc = 0.0f;
// Base pointers for current group
const InDataType* input_g = p_ins[0] + g * in_stride_g;
const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g;
// Loop over batch and output positions
for(index_t n = 0; n < N; ++n)
{
const InDataType* in_gn = in_g + n * in_stride_n + c * in_stride_c;
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
// Pointers at current batch and input channel
const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c;
const OutDataType* output_grad_at_n_k =
output_grad_g + n * out_stride_n + k * out_stride_k;
for(index_t wo = 0; wo < Wo; ++wo)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gn[wi]);
out_op(out_val, out_gn_k[wo]);
// Handle input element-wise operation with extra A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
in_val,
in_op,
input_at_n_c,
p_ins + 1,
g * in_stride_g + n * in_stride_n + c * in_stride_c,
wi);
// Handle output gradient element-wise operation with extra B tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
out_val,
out_op,
output_grad_at_n_k,
p_out_grads + 1,
g * out_stride_g + n * out_stride_n + k * out_stride_k,
wo);
acc += type_convert<float>(out_val) * type_convert<float>(in_val);
}
}
}
WeiDataType result = type_convert<WeiDataType>(acc);
wei_op(wei_val, result);
detail::apply_d_tensor_elementwise_op<NumD>(
wei_val, wei_op, acc, p_ds, p_d_strides, g, k, c, x);
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + x] = wei_val;
}
}
@@ -139,31 +168,55 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
const index_t k = remaining % K;
const index_t g = remaining / K;
float acc = 0.0f;
const InDataType* in_g = p_in + g * in_stride_g;
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
float acc = 0.0f;
// Base pointers for current group
const InDataType* input_g = p_ins[0] + g * in_stride_g;
const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g;
// Loop over batch and output positions
for(index_t n = 0; n < N; ++n)
{
const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c;
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
// Pointers at current batch and input channel
const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c;
const OutDataType* output_grad_at_n_k =
output_grad_g + n * out_stride_n + k * out_stride_k;
for(index_t ho = 0; ho < Ho; ++ho)
{
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
if(hi >= 0 && hi < Hi)
{
const InDataType* in_gnch = in_gnc + hi * in_stride_h;
const OutDataType* out_gn_kh = out_gn_k + ho * out_stride_h;
// Pointers at current spatial height
const InDataType* input_at_h = input_at_n_c + hi * in_stride_h;
const OutDataType* output_grad_at_h =
output_grad_at_n_k + ho * out_stride_h;
for(index_t wo = 0; wo < Wo; ++wo)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gnch[wi]);
out_op(out_val, out_gn_kh[wo]);
// Handle input element-wise operation with extra A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
in_val,
in_op,
input_at_h,
p_ins + 1,
g * in_stride_g + n * in_stride_n + c * in_stride_c +
hi * in_stride_h,
wi);
// Handle output gradient element-wise operation with extra B
// tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
out_val,
out_op,
output_grad_at_h,
p_out_grads + 1,
g * out_stride_g + n * out_stride_n + k * out_stride_k +
ho * out_stride_h,
wo);
acc += type_convert<float>(out_val) * type_convert<float>(in_val);
}
}
@@ -171,8 +224,17 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
}
}
WeiDataType result = type_convert<WeiDataType>(acc);
wei_op(wei_val, result);
detail::apply_d_tensor_elementwise_op<NumD>(wei_val,
wei_op,
acc,
p_ds,
p_d_strides,
g,
k,
c,
y * p_d_strides[0][3] +
x * p_d_strides[0][4]);
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + y * wei_stride_y +
x] = wei_val;
}
@@ -210,39 +272,65 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
const index_t k = remaining % K;
const index_t g = remaining / K;
float acc = 0.0f;
const InDataType* in_g = p_in + g * in_stride_g;
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
float acc = 0.0f;
// Base pointers for current group
const InDataType* input_g = p_ins[0] + g * in_stride_g;
const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g;
// Loop over batch and output positions
for(index_t n = 0; n < N; ++n)
{
const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c;
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
// Pointers at current batch and input channel
const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c;
const OutDataType* output_grad_at_n_k =
output_grad_g + n * out_stride_n + k * out_stride_k;
for(index_t do_idx = 0; do_idx < Do; ++do_idx)
{
long_index_t di = do_idx * stride_z + z * dilation_z - pad_z;
if(di >= 0 && di < Di)
{
const InDataType* in_gncd = in_gnc + di * in_stride_d;
const OutDataType* out_gn_kd = out_gn_k + do_idx * out_stride_d;
// Pointers at current spatial depth
const InDataType* input_at_d = input_at_n_c + di * in_stride_d;
const OutDataType* output_grad_at_d =
output_grad_at_n_k + do_idx * out_stride_d;
for(index_t ho = 0; ho < Ho; ++ho)
{
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
if(hi >= 0 && hi < Hi)
{
const InDataType* in_gncdh = in_gncd + hi * in_stride_h;
const OutDataType* out_gn_kdh = out_gn_kd + ho * out_stride_h;
// Pointers at current spatial depth and height
const InDataType* input_at_d_h = input_at_d + hi * in_stride_h;
const OutDataType* output_grad_at_d_h =
output_grad_at_d + ho * out_stride_h;
for(index_t wo = 0; wo < Wo; ++wo)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gncdh[wi]);
out_op(out_val, out_gn_kdh[wo]);
// Handle input element-wise operation with extra A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
in_val,
in_op,
input_at_d_h,
p_ins + 1,
g * in_stride_g + n * in_stride_n + c * in_stride_c +
di * in_stride_d + hi * in_stride_h,
wi);
// Handle output gradient element-wise operation with extra
// B tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
out_val,
out_op,
output_grad_at_d_h,
p_out_grads + 1,
g * out_stride_g + n * out_stride_n + k * out_stride_k +
do_idx * out_stride_d + ho * out_stride_h,
wo);
acc += type_convert<float>(out_val) *
type_convert<float>(in_val);
}
@@ -253,16 +341,28 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
}
}
WeiDataType result = type_convert<WeiDataType>(acc);
wei_op(wei_val, result);
detail::apply_d_tensor_elementwise_op<NumD>(
wei_val,
wei_op,
acc,
p_ds,
p_d_strides,
g,
k,
c,
z * p_d_strides[0][3] + y * p_d_strides[0][4] + x * p_d_strides[0][5]);
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + z * wei_stride_z +
y * wei_stride_y + x] = wei_val;
}
}
}
// GPU reference backward weight convolution - takes ConvParam directly
template <typename InLayout,
// GPU reference backward weight convolution with multi-ABD support - takes ConvParam directly
template <ck::index_t NumAElementwise = 0,
ck::index_t NumBElementwise = 0,
ck::index_t NumDElementwise = 0,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename TIn,
@@ -270,15 +370,20 @@ template <typename InLayout,
typename TOut,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
void naive_conv_bwd_weight(const TIn* p_in,
TWei* p_wei_grad,
const TOut* p_out,
const ck::utils::conv::ConvParam& conv_param,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
typename OutElementwiseOperation,
typename TD = TWei> // D tensor type, defaults to TWei for backward compatibility
void naive_conv_bwd_weight_multi_abd(
const std::array<const TIn*, NumAElementwise + 1>& p_ins,
TWei* p_wei_grad,
const std::array<const TOut*, NumBElementwise + 1>& p_outs,
const std::array<const TD*, NumDElementwise>& p_ds,
const ck::utils::conv::ConvParam& conv_param,
[[maybe_unused]] const std::array<std::vector<index_t>, NumDElementwise>& d_lengths,
const std::array<std::vector<index_t>, NumDElementwise>& d_strides,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
{
const auto ndim = conv_param.num_dim_spatial_;
@@ -308,13 +413,35 @@ void naive_conv_bwd_weight(const TIn* p_in,
out_total *= l;
// Allocate packed buffers
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei));
SimpleDeviceMem out_grad_packed_buf(out_total * sizeof(TOut));
std::vector<SimpleDeviceMem> in_packed_bufs;
in_packed_bufs.reserve(NumAElementwise + 1);
for(index_t i = 0; i <= NumAElementwise; ++i)
{
in_packed_bufs.emplace_back(in_total * sizeof(TIn));
}
SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei));
std::vector<SimpleDeviceMem> out_grad_packed_bufs;
out_grad_packed_bufs.reserve(NumBElementwise + 1);
for(index_t i = 0; i <= NumBElementwise; ++i)
{
out_grad_packed_bufs.emplace_back(out_total * sizeof(TOut));
}
std::array<TIn*, NumAElementwise + 1> p_ins_packed;
for(index_t i = 0; i <= NumAElementwise; ++i)
{
p_ins_packed[i] = static_cast<TIn*>(in_packed_bufs[i].GetDeviceBuffer());
}
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
TWei* p_wei_grad_packed = static_cast<TWei*>(wei_grad_packed_buf.GetDeviceBuffer());
TOut* p_out_grad_packed = static_cast<TOut*>(out_grad_packed_buf.GetDeviceBuffer());
std::array<TOut*, NumBElementwise + 1> p_out_grads_packed;
for(index_t i = 0; i <= NumBElementwise; ++i)
{
p_out_grads_packed[i] = static_cast<TOut*>(out_grad_packed_bufs[i].GetDeviceBuffer());
}
// Compute strides and allocate device arrays for pack/unpack
std::vector<index_t> in_strides = compute_conv_tensor_strides<InLayout>(in_lengths, ndim);
@@ -351,12 +478,81 @@ void naive_conv_bwd_weight(const TIn* p_in,
// Pack input and output_grad tensors to contiguous layout (inputs to bwd weight)
constexpr int block_size = 256;
strided_copy_kernel<TIn, false>
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total);
strided_copy_kernel<TOut, false>
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_out, p_out_grad_packed, d_out_lengths, d_out_strides, dim_count, out_total);
for(index_t i = 0; i <= NumAElementwise; ++i)
{
strided_copy_kernel<TIn, false>
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total);
}
for(index_t i = 0; i <= NumBElementwise; ++i)
{
strided_copy_kernel<TOut, false>
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_outs[i],
p_out_grads_packed[i],
d_out_lengths,
d_out_strides,
dim_count,
out_total);
}
// Prepare D tensor stride arrays on device
std::vector<SimpleDeviceMem> d_stride_bufs;
std::array<index_t*, NumDElementwise> p_d_strides_dev = {};
if constexpr(NumDElementwise > 0)
{
d_stride_bufs.reserve(NumDElementwise);
for(index_t i = 0; i < NumDElementwise; ++i)
{
d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t));
p_d_strides_dev[i] = static_cast<index_t*>(d_stride_bufs[i].GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i],
d_strides[i].data(),
d_strides[i].size() * sizeof(index_t),
hipMemcpyHostToDevice));
}
}
// Create device arrays of pointers
SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*));
SimpleDeviceMem out_grads_ptrs_buf((NumBElementwise + 1) * sizeof(TOut*));
SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*));
SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*));
TIn** d_ins_ptrs = static_cast<TIn**>(ins_ptrs_buf.GetDeviceBuffer());
TOut** d_out_grads_ptrs = static_cast<TOut**>(out_grads_ptrs_buf.GetDeviceBuffer());
TD** d_ds_ptrs = static_cast<TD**>(ds_ptrs_buf.GetDeviceBuffer());
index_t** d_d_strides_ptrs = static_cast<index_t**>(d_strides_ptrs_buf.GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs,
p_ins_packed.data(),
(NumAElementwise + 1) * sizeof(TIn*),
hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_out_grads_ptrs,
p_out_grads_packed.data(),
(NumBElementwise + 1) * sizeof(TOut*),
hipMemcpyHostToDevice));
if constexpr(NumDElementwise > 0)
{
std::array<const TD*, NumDElementwise> p_ds_dev;
for(index_t i = 0; i < NumDElementwise; ++i)
{
p_ds_dev[i] = p_ds[i];
}
HIP_CHECK_ERROR(hipMemcpy(
d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs,
p_d_strides_dev.data(),
NumDElementwise * sizeof(index_t*),
hipMemcpyHostToDevice));
}
// Build conv parameter vectors for kernel invocation
std::vector<index_t> conv_strides(ndim);
@@ -374,16 +570,22 @@ void naive_conv_bwd_weight(const TIn* p_in,
if(ndim == 1)
{
naive_conv_bwd_weight_packed<1,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
naive_conv_bwd_weight_packed_multi_abd<1,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<wei_grid, block_size, 0, stream>>>(d_ins_ptrs,
p_wei_grad_packed,
p_out_grad_packed,
d_out_grads_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
G,
N,
K,
@@ -412,16 +614,22 @@ void naive_conv_bwd_weight(const TIn* p_in,
}
else if(ndim == 2)
{
naive_conv_bwd_weight_packed<2,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
naive_conv_bwd_weight_packed_multi_abd<2,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<wei_grid, block_size, 0, stream>>>(d_ins_ptrs,
p_wei_grad_packed,
p_out_grad_packed,
d_out_grads_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
G,
N,
K,
@@ -450,16 +658,22 @@ void naive_conv_bwd_weight(const TIn* p_in,
}
else // 3D
{
naive_conv_bwd_weight_packed<3,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
naive_conv_bwd_weight_packed_multi_abd<3,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<wei_grid, block_size, 0, stream>>>(d_ins_ptrs,
p_wei_grad_packed,
p_out_grad_packed,
d_out_grads_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
G,
N,
K,
@@ -496,5 +710,44 @@ void naive_conv_bwd_weight(const TIn* p_in,
// Memory automatically freed by SimpleDeviceMem destructors
}
// Single-tensor entry point for the GPU reference backward-weight convolution.
//
// Kept so existing callers that pass one input tensor and one output-gradient tensor keep
// compiling. It forwards to naive_conv_bwd_weight_multi_abd with zero extra A tensors, zero
// extra B tensors and zero D tensors. The forwarding itself adds no work, but the multi-ABD
// implementation stages its tensor-pointer tables on the device before launching the kernel,
// so this is a thin wrapper rather than a strictly cost-free one.
//
// p_in / p_out    input tensor and output-gradient tensor, read-only
// p_wei_grad      destination (weight gradient) tensor, written by the callee
// conv_param      convolution problem description, forwarded unchanged
// *_element_op    element-wise functors, forwarded unchanged
// stream          HIP stream, forwarded unchanged (defaults to nullptr)
template <typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename TIn,
          typename TWei,
          typename TOut,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
inline void
naive_conv_bwd_weight(const TIn* p_in,
                      TWei* p_wei_grad,
                      const TOut* p_out,
                      const ck::utils::conv::ConvParam& conv_param,
                      InElementwiseOperation in_element_op   = InElementwiseOperation{},
                      WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
                      OutElementwiseOperation out_element_op = OutElementwiseOperation{},
                      hipStream_t stream                     = nullptr)
{
    // The multi-ABD interface takes arrays of tensor pointers; wrap each lone pointer in a
    // one-element array (slot 0 is the primary tensor).
    std::array<const TIn*, 1> input_ptrs{p_in};
    std::array<const TOut*, 1> output_ptrs{p_out};

    // No D tensors in the single-tensor case: every D-related argument is an empty array.
    // The D pointer array is typed on TWei, which is the callee's default D data type.
    std::array<const TWei*, 0> no_d_ptrs{};
    std::array<std::vector<index_t>, 0> no_d_lengths{};
    std::array<std::vector<index_t>, 0> no_d_strides{};

    naive_conv_bwd_weight_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(
        input_ptrs,
        p_wei_grad,
        output_ptrs,
        no_d_ptrs,
        conv_param,
        no_d_lengths,
        no_d_strides,
        in_element_op,
        wei_element_op,
        out_element_op,
        stream);
}
} // namespace ref
} // namespace ck

View File

@@ -10,48 +10,56 @@
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <array>
namespace ck {
namespace ref {
// Optimized convolution kernel working with packed (contiguous) tensors
// Optimized convolution kernel working with packed (contiguous) tensors with multi-ABD support
// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter],
// output[G][N][K][spatial]
template <index_t NDimSpatial,
index_t NumAExtra, // Number of extra A (input) tensors
index_t NumBExtra, // Number of extra B (weight) tensors
index_t NumD, // Number of D tensors
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename DDataType, // D tensor data type
typename InElementOp,
typename WeiElementOp,
typename OutElementOp>
__global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
const WeiDataType* __restrict__ p_wei,
OutDataType* __restrict__ p_out,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
__global__ void naive_conv_fwd_packed_multi_abd(
const InDataType* const* __restrict__ p_ins, // Array of input pointers (1 + NumAExtra)
const WeiDataType* const* __restrict__ p_weis, // Array of weight pointers (1 + NumBExtra)
const DDataType* const* __restrict__ p_ds, // Array of D tensor pointers
const index_t* const* __restrict__ p_d_strides, // Array of D tensor stride arrays
OutDataType* __restrict__ p_out,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
{
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const long_index_t num_threads = blockDim.x * gridDim.x;
@@ -83,29 +91,48 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const InDataType* in_g = p_in + g * in_stride_g + n * in_stride_n;
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
float acc = 0.0f;
// Base pointers for current group, batch, and output channel
const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n;
const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k;
for(index_t c = 0; c < C; ++c)
{
const InDataType* in_gc = in_g + c * in_stride_c;
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
// Pointers at current input channel
const InDataType* input_at_c = input_g_n + c * in_stride_c;
const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c;
for(index_t x = 0; x < X; ++x)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gc[wi]);
wei_op(wei_val, wei_gkc[x]);
// Handle input element-wise operation with extra A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
in_val,
in_op,
input_at_c,
p_ins + 1,
g * in_stride_g + n * in_stride_n + c * in_stride_c,
wi);
// Handle weight element-wise operation with extra B tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
wei_val,
wei_op,
weight_at_c,
p_weis + 1,
g * wei_stride_g + k * wei_stride_k + c * wei_stride_c,
x);
acc += type_convert<float>(in_val) * type_convert<float>(wei_val);
}
}
}
OutDataType result = type_convert<OutDataType>(acc);
out_op(out_val, result);
detail::apply_d_tensor_elementwise_op<NumD>(
out_val, out_op, acc, p_ds, p_d_strides, g, n, k, wo);
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + wo] = out_val;
}
}
@@ -137,30 +164,51 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n;
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
float acc = 0.0f;
// Base pointers for current group, batch, and output channel
const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n;
const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k;
for(index_t c = 0; c < C; ++c)
{
const InDataType* in_gnc = in_gn + c * in_stride_c;
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
// Pointers at current input channel
const InDataType* input_at_c = input_g_n + c * in_stride_c;
const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c;
for(index_t y = 0; y < Y; ++y)
{
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
if(hi >= 0 && hi < Hi)
{
const InDataType* in_gnch = in_gnc + hi * in_stride_h;
const WeiDataType* wei_gkcy = wei_gkc + y * wei_stride_y;
// Pointers at current spatial height and filter Y position
const InDataType* input_at_h = input_at_c + hi * in_stride_h;
const WeiDataType* weight_at_y = weight_at_c + y * wei_stride_y;
for(index_t x = 0; x < X; ++x)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gnch[wi]);
wei_op(wei_val, wei_gkcy[x]);
// Handle input element-wise operation with extra A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
in_val,
in_op,
input_at_h,
p_ins + 1,
g * in_stride_g + n * in_stride_n + c * in_stride_c +
hi * in_stride_h,
wi);
// Handle weight element-wise operation with extra B tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
wei_val,
wei_op,
weight_at_y,
p_weis + 1,
g * wei_stride_g + k * wei_stride_k + c * wei_stride_c +
y * wei_stride_y,
x);
acc += type_convert<float>(in_val) * type_convert<float>(wei_val);
}
}
@@ -168,8 +216,17 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
}
}
OutDataType result = type_convert<OutDataType>(acc);
out_op(out_val, result);
detail::apply_d_tensor_elementwise_op<NumD>(out_val,
out_op,
acc,
p_ds,
p_d_strides,
g,
n,
k,
ho * p_d_strides[0][3] +
wo * p_d_strides[0][4]);
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + ho * out_stride_h + wo] =
out_val;
}
@@ -207,38 +264,60 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n;
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
float acc = 0.0f;
// Base pointers for current group, batch, and output channel
const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n;
const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k;
for(index_t c = 0; c < C; ++c)
{
const InDataType* in_gnc = in_gn + c * in_stride_c;
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
// Pointers at current input channel
const InDataType* input_at_c = input_g_n + c * in_stride_c;
const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c;
for(index_t z = 0; z < Z; ++z)
{
long_index_t di = do_idx * stride_z + z * dilation_z - pad_z;
if(di >= 0 && di < Di)
{
const InDataType* in_gncd = in_gnc + di * in_stride_d;
const WeiDataType* wei_gkcz = wei_gkc + z * wei_stride_z;
// Pointers at current spatial depth
const InDataType* input_at_d = input_at_c + di * in_stride_d;
const WeiDataType* weight_at_z = weight_at_c + z * wei_stride_z;
for(index_t y = 0; y < Y; ++y)
{
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
if(hi >= 0 && hi < Hi)
{
const InDataType* in_gncdh = in_gncd + hi * in_stride_h;
const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y;
// Pointers at current spatial depth and height
const InDataType* input_at_d_h = input_at_d + hi * in_stride_h;
const WeiDataType* weight_at_z_y = weight_at_z + y * wei_stride_y;
for(index_t x = 0; x < X; ++x)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gncdh[wi]);
wei_op(wei_val, wei_gkczy[x]);
// Handle input element-wise operation with extra A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
in_val,
in_op,
input_at_d_h,
p_ins + 1,
g * in_stride_g + n * in_stride_n + c * in_stride_c +
di * in_stride_d + hi * in_stride_h,
wi);
// Handle weight element-wise operation with extra B tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
wei_val,
wei_op,
weight_at_z_y,
p_weis + 1,
g * wei_stride_g + k * wei_stride_k + c * wei_stride_c +
z * wei_stride_z + y * wei_stride_y,
x);
acc += type_convert<float>(in_val) *
type_convert<float>(wei_val);
}
@@ -249,16 +328,28 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
}
}
OutDataType result = type_convert<OutDataType>(acc);
out_op(out_val, result);
detail::apply_d_tensor_elementwise_op<NumD>(
out_val,
out_op,
acc,
p_ds,
p_d_strides,
g,
n,
k,
do_idx * p_d_strides[0][3] + ho * p_d_strides[0][4] + wo * p_d_strides[0][5]);
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + do_idx * out_stride_d +
ho * out_stride_h + wo] = out_val;
}
}
}
// GPU reference convolution - takes ConvParam directly
template <typename InLayout,
// GPU reference convolution with multi-ABD support - takes ConvParam directly
template <ck::index_t NumAElementwise = 0,
ck::index_t NumBElementwise = 0,
ck::index_t NumDElementwise = 0,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename TIn,
@@ -266,15 +357,20 @@ template <typename InLayout,
typename TOut,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
void naive_conv_fwd(const TIn* p_in,
const TWei* p_wei,
TOut* p_out,
const ck::utils::conv::ConvParam& conv_param,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
typename OutElementwiseOperation,
typename TD = TOut> // D tensor type, defaults to TOut for backward compatibility
void naive_conv_fwd_multi_abd(
const std::array<const TIn*, NumAElementwise + 1>& p_ins,
const std::array<const TWei*, NumBElementwise + 1>& p_weis,
const std::array<const TD*, NumDElementwise>& p_ds,
TOut* p_out,
const ck::utils::conv::ConvParam& conv_param,
[[maybe_unused]] const std::array<std::vector<index_t>, NumDElementwise>& d_lengths,
const std::array<std::vector<index_t>, NumDElementwise>& d_strides,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
{
const auto ndim = conv_param.num_dim_spatial_;
@@ -303,13 +399,37 @@ void naive_conv_fwd(const TIn* p_in,
for(auto l : out_lengths)
out_total *= l;
// Allocate packed buffers
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei));
// Allocate packed buffers for all A and B tensors
// Use separate allocations to avoid copy assignment issues with RAII wrapper
std::vector<SimpleDeviceMem> in_packed_bufs;
in_packed_bufs.reserve(NumAElementwise + 1);
for(index_t i = 0; i <= NumAElementwise; ++i)
{
in_packed_bufs.emplace_back(in_total * sizeof(TIn));
}
std::vector<SimpleDeviceMem> wei_packed_bufs;
wei_packed_bufs.reserve(NumBElementwise + 1);
for(index_t i = 0; i <= NumBElementwise; ++i)
{
wei_packed_bufs.emplace_back(wei_total * sizeof(TWei));
}
SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut));
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
TWei* p_wei_packed = static_cast<TWei*>(wei_packed_buf.GetDeviceBuffer());
// Get packed buffer pointers
std::array<TIn*, NumAElementwise + 1> p_ins_packed;
for(index_t i = 0; i <= NumAElementwise; ++i)
{
p_ins_packed[i] = static_cast<TIn*>(in_packed_bufs[i].GetDeviceBuffer());
}
std::array<TWei*, NumBElementwise + 1> p_weis_packed;
for(index_t i = 0; i <= NumBElementwise; ++i)
{
p_weis_packed[i] = static_cast<TWei*>(wei_packed_bufs[i].GetDeviceBuffer());
}
TOut* p_out_packed = static_cast<TOut*>(out_packed_buf.GetDeviceBuffer());
// Compute strides and allocate device arrays for pack/unpack
@@ -347,12 +467,82 @@ void naive_conv_fwd(const TIn* p_in,
// Pack input and weight tensors to contiguous layout
constexpr int block_size = 256;
strided_copy_kernel<TIn, false>
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total);
strided_copy_kernel<TWei, false>
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total);
// Pack all A tensors
for(index_t i = 0; i <= NumAElementwise; ++i)
{
strided_copy_kernel<TIn, false>
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total);
}
// Pack all B tensors
for(index_t i = 0; i <= NumBElementwise; ++i)
{
strided_copy_kernel<TWei, false>
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total);
}
// Prepare D tensor stride arrays on device
// NOTE: D tensors are NOT packed - they are used directly with their original strides
// to support broadcasting (e.g., BiasGK layout with zero strides)
std::vector<SimpleDeviceMem> d_stride_bufs;
std::array<index_t*, NumDElementwise> p_d_strides_dev = {};
if constexpr(NumDElementwise > 0)
{
d_stride_bufs.reserve(NumDElementwise);
for(index_t i = 0; i < NumDElementwise; ++i)
{
// Allocate and copy strides to device
d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t));
p_d_strides_dev[i] = static_cast<index_t*>(d_stride_bufs[i].GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i],
d_strides[i].data(),
d_strides[i].size() * sizeof(index_t),
hipMemcpyHostToDevice));
}
}
// Create device arrays of pointers
SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*));
SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*));
SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*));
SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*));
TIn** d_ins_ptrs = static_cast<TIn**>(ins_ptrs_buf.GetDeviceBuffer());
TWei** d_weis_ptrs = static_cast<TWei**>(weis_ptrs_buf.GetDeviceBuffer());
TD** d_ds_ptrs = static_cast<TD**>(ds_ptrs_buf.GetDeviceBuffer());
index_t** d_d_strides_ptrs = static_cast<index_t**>(d_strides_ptrs_buf.GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs,
p_ins_packed.data(),
(NumAElementwise + 1) * sizeof(TIn*),
hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs,
p_weis_packed.data(),
(NumBElementwise + 1) * sizeof(TWei*),
hipMemcpyHostToDevice));
if constexpr(NumDElementwise > 0)
{
// D tensors use original pointers (not packed) to support broadcasting
std::array<const TD*, NumDElementwise> p_ds_dev;
for(index_t i = 0; i < NumDElementwise; ++i)
{
p_ds_dev[i] = p_ds[i];
}
HIP_CHECK_ERROR(hipMemcpy(
d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs,
p_d_strides_dev.data(),
NumDElementwise * sizeof(index_t*),
hipMemcpyHostToDevice));
}
// Build conv parameter vectors for kernel invocation
std::vector<index_t> conv_strides(ndim);
@@ -370,15 +560,21 @@ void naive_conv_fwd(const TIn* p_in,
if(ndim == 1)
{
naive_conv_fwd_packed<1,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
naive_conv_fwd_packed_multi_abd<1,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(d_ins_ptrs,
d_weis_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
p_out_packed,
G,
N,
@@ -408,15 +604,21 @@ void naive_conv_fwd(const TIn* p_in,
}
else if(ndim == 2)
{
naive_conv_fwd_packed<2,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
naive_conv_fwd_packed_multi_abd<2,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(d_ins_ptrs,
d_weis_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
p_out_packed,
G,
N,
@@ -446,15 +648,21 @@ void naive_conv_fwd(const TIn* p_in,
}
else // 3D
{
naive_conv_fwd_packed<3,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
naive_conv_fwd_packed_multi_abd<3,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(d_ins_ptrs,
d_weis_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
p_out_packed,
G,
N,
@@ -492,5 +700,43 @@ void naive_conv_fwd(const TIn* p_in,
// Memory automatically freed by SimpleDeviceMem destructors
}
// Single-input / single-weight / no-D-tensor entry point.
// Keeps the original naive_conv_fwd signature and forwards to
// naive_conv_fwd_multi_abd<0, 0, 0, ...> with one-element A and B pointer arrays and
// empty D pointer / length / stride arrays.
template <typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename TIn,
          typename TWei,
          typename TOut,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
inline void naive_conv_fwd(const TIn* p_in,
                           const TWei* p_wei,
                           TOut* p_out,
                           const ck::utils::conv::ConvParam& conv_param,
                           InElementwiseOperation in_element_op   = InElementwiseOperation{},
                           WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
                           OutElementwiseOperation out_element_op = OutElementwiseOperation{},
                           hipStream_t stream                     = nullptr)
{
    // Exactly one A (input) tensor and one B (weight) tensor.
    const std::array<const TIn*, 1> single_in{p_in};
    const std::array<const TWei*, 1> single_wei{p_wei};

    // No D tensors: zero-length arrays; the D element type is deduced as TOut.
    const std::array<const TOut*, 0> no_d_ptrs{};
    const std::array<std::vector<index_t>, 0> no_d_lengths{};
    const std::array<std::vector<index_t>, 0> no_d_strides{};

    naive_conv_fwd_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(single_in,
                                                                      single_wei,
                                                                      no_d_ptrs,
                                                                      p_out,
                                                                      conv_param,
                                                                      no_d_lengths,
                                                                      no_d_strides,
                                                                      in_element_op,
                                                                      wei_element_op,
                                                                      out_element_op,
                                                                      stream);
}
} // namespace ref
} // namespace ck

View File

@@ -22,9 +22,39 @@ struct SimpleDeviceMem
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&p_mem_), mem_size));
}
// Delete copy operations (resource should not be copied)
SimpleDeviceMem(const SimpleDeviceMem&) = delete;
SimpleDeviceMem& operator=(const SimpleDeviceMem&) = delete;
// Define move operations
SimpleDeviceMem(SimpleDeviceMem&& other) noexcept : p_mem_(other.p_mem_)
{
other.p_mem_ = nullptr;
}
// Move assignment: frees the buffer currently held (if any), then adopts other's pointer
// and nulls it in the source. Self-assignment leaves the object untouched.
SimpleDeviceMem& operator=(SimpleDeviceMem&& other) noexcept
{
    if(this == &other)
    {
        return *this;
    }

    // Release our own allocation before overwriting the pointer.
    if(p_mem_ != nullptr)
    {
        (void)hipFree(p_mem_);
    }

    p_mem_       = other.p_mem_;
    other.p_mem_ = nullptr;
    return *this;
}
// Returns the raw pointer stored in p_mem_; this is null once the object has been moved from.
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
// Frees the held buffer with hipFree; a null p_mem_ (e.g. a moved-from object) is skipped.
// The hipFree status is intentionally discarded.
~SimpleDeviceMem()
{
    if(p_mem_ != nullptr)
    {
        (void)hipFree(p_mem_);
    }
}
void* p_mem_;
};
@@ -173,5 +203,90 @@ __global__ void strided_copy_kernel(const DataType* __restrict__ src,
}
}
namespace detail {
// Pack-expansion helper for the multi-tensor (A / B operand) path.
// For every index in TensorIds, reads the element at `element_offset` from
// tensor_ptrs[index] and passes all of them, in index order, to `element_op` after the
// output reference `result`.
template <typename ResultType, typename Op, typename DataType, std::size_t... TensorIds>
__device__ __forceinline__ void apply_multi_tensor_impl(ResultType& result,
                                                        Op&& element_op,
                                                        const DataType* const* tensor_ptrs,
                                                        long_index_t element_offset,
                                                        std::index_sequence<TensorIds...>)
{
    element_op(result, *(tensor_ptrs[TensorIds] + element_offset)...);
}
// Generic helper for A and B tensors.
// Calls `element_op(result, t0, t1, ..., tN)` where t0 is the element of the primary
// tensor and t1..tN are the matching elements of the NumExtraTensors extra tensors.
//   result            - output reference handed to element_op
//   element_op        - element-wise functor
//   primary_ptr       - primary tensor pointer, used as passed (extra_base_offset is NOT added)
//   extra_ptrs        - array of NumExtraTensors pointers to the extra tensors
//   extra_base_offset - offset added to every extra pointer before indexing
//   element_offset    - index applied to all NumExtraTensors + 1 pointers
template <index_t NumExtraTensors, typename DataType, typename ResultType, typename Op>
__device__ __forceinline__ void apply_multi_tensor_elementwise_op(ResultType& result,
                                                                  Op&& element_op,
                                                                  const DataType* primary_ptr,
                                                                  const DataType* const* extra_ptrs,
                                                                  long_index_t extra_base_offset,
                                                                  long_index_t element_offset)
{
    constexpr index_t NumTensors = NumExtraTensors + 1;
    using TensorIndices          = std::make_index_sequence<NumTensors>;

    // Slot 0 holds the primary tensor; slots 1..NumExtraTensors hold the rebased extras.
    const DataType* operand_ptrs[NumTensors];
    operand_ptrs[0] = primary_ptr;
    static_for<1, NumTensors, 1>{}(
        [&](auto slot) { operand_ptrs[slot] = extra_ptrs[slot - 1] + extra_base_offset; });

    apply_multi_tensor_impl(result, element_op, operand_ptrs, element_offset, TensorIndices{});
}
// Pack-expansion helper for the D-tensor path.
// Calls `element_op(tmp, computed_value, d_values[0], ..., d_values[N-1])` with a float
// temporary as the output argument, then converts that temporary to OutDataType and
// stores it in `result_out`.
template <typename OutDataType, typename Op, std::size_t... DIds>
__device__ __forceinline__ void apply_d_tensor_impl(OutDataType& result_out,
                                                    Op&& element_op,
                                                    float computed_value,
                                                    const float* d_values,
                                                    std::index_sequence<DIds...>)
{
    // Written by element_op before it is read below.
    float op_result;
    element_op(op_result, computed_value, d_values[DIds]...);
    result_out = type_convert<OutDataType>(op_result);
}
// Specialized helper for D tensors with stride calculations and float conversion.
//
// With NumDTensors == 0 this is simply `element_op(result_out, computed_value)`.
// Otherwise, for each D tensor i the element at
//     g * p_d_strides[i][0] + n * p_d_strides[i][1] + c_or_k * p_d_strides[i][2]
//       + spatial_linear_index
// is loaded, converted to float, and all D values are passed to `element_op` together
// with `computed_value`; the float result is converted to OutDataType.
//
//   result_out           - receives the final value
//   element_op           - output element-wise functor
//   computed_value       - accumulated convolution result
//   p_ds                 - array of NumDTensors D tensor pointers
//   p_d_strides          - per-D-tensor stride arrays; entries 0, 1, 2 are the strides
//                          applied to g, n and c_or_k respectively
//   g, n, c_or_k         - group, batch and channel indices of the output element
//   spatial_linear_index - spatial part of the offset, precomputed by the caller and
//                          added unchanged for every D tensor
template <index_t NumDTensors, typename DDataType, typename OutDataType, typename Op>
__device__ __forceinline__ void apply_d_tensor_elementwise_op(OutDataType& result_out,
                                                              Op&& element_op,
                                                              float computed_value,
                                                              const DDataType* const* p_ds,
                                                              const index_t* const* p_d_strides,
                                                              index_t g,
                                                              index_t n,
                                                              index_t c_or_k,
                                                              long_index_t spatial_linear_index)
{
    if constexpr(NumDTensors == 0)
    {
        element_op(result_out, computed_value);
    }
    else
    {
        // Widen the indices first: index * stride on two index_t operands would be
        // evaluated at index_t width and could overflow before being stored in the
        // long_index_t offset.
        const long_index_t g_idx = g;
        const long_index_t n_idx = n;
        const long_index_t k_idx = c_or_k;

        float d_values[NumDTensors];

        // Compute all D tensor indices and convert to float
        static_for<0, NumDTensors, 1>{}([&](auto i) {
            const index_t* strides   = p_d_strides[i];
            const long_index_t d_idx = g_idx * strides[0] + n_idx * strides[1] +
                                       k_idx * strides[2] + spatial_linear_index;
            d_values[i] = type_convert<float>(p_ds[i][d_idx]);
        });

        apply_d_tensor_impl(result_out,
                            element_op,
                            computed_value,
                            d_values,
                            std::make_index_sequence<NumDTensors>{});
    }
}
} // namespace detail
} // namespace ref
} // namespace ck

View File

@@ -376,7 +376,7 @@ using device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances =
// clang-format on
>;
#if defined(__gfx950__)
#if defined(CK_USE_GFX950)
constexpr auto _k_per_block = 32;
#else
constexpr auto _k_per_block = 16;

View File

@@ -147,7 +147,7 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple<
// clang-format on
>;
#if defined(__gfx950__)
#if defined(CK_USE_GFX950)
constexpr auto _k_per_block = 32;
#else
constexpr auto _k_per_block = 16;

View File

@@ -47,7 +47,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::t
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
} // namespace instance

View File

@@ -40,7 +40,7 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::t
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
@@ -49,7 +49,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::t
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
} // namespace instance

View File

@@ -41,7 +41,7 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::t
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
@@ -52,7 +52,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::t
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
} // namespace instance

View File

@@ -44,7 +44,7 @@ using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = std::tupl
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,

View File

@@ -40,9 +40,9 @@ using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = std::tupl
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,

View File

@@ -41,7 +41,7 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
@@ -49,7 +49,7 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,

View File

@@ -7,12 +7,17 @@
#include "../../experimental/builder/test/utils/conv_algorithm_type_utils.hpp"
#include "grouped_convolution_signatures.hpp"
#include "ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp"
#include "ck_tile/builder/testing/filter_extent.hpp"
#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp"
#include "ck_tile/builder/testing/conv_fwd_reference.hpp"
#include "ck_tile/builder/testing/conv/fwd.hpp"
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
#include "ck_tile/builder/testing/conv/reference.hpp"
#include "ck_tile/builder/conv_builder.hpp"
// Temporarily disable builder validation because deduced rtol/atol support is not available yet
#define ENABLE_BUILDER_VALIDATE 0
namespace ck_tile::builder::profiling {
namespace ckb = ck_tile::builder;
@@ -113,25 +118,66 @@ run_grouped_conv_forward_tile_algs(const ckt::Args<SIGNATURE>& args,
auto reference = ckt::alloc_outputs(args);
using ReferenceInstance =
typename ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
auto ref_conv = ReferenceInstance{};
ckt::run(ref_conv, args, inputs, reference.get());
auto ref_conv = ReferenceInstance{};
[[maybe_unused]] auto ref_result = ckt::run(ref_conv, args, inputs, reference.get());
#if ENABLE_BUILDER_VALIDATE == 0
using DataType =
std::conditional_t<SIGNATURE.data_type == ckb::DataType::FP32,
float,
std::conditional_t<SIGNATURE.data_type == ckb::DataType::FP16,
ck_tile::half_t,
ck_tile::bfloat16_t>>;
const auto conv_param = args.to_ck_tile_conv_param();
const std::size_t output_bytes_num = conv_param.template GetOutputByte<DataType>();
std::vector<DataType> out(output_bytes_num / sizeof(DataType));
std::vector<DataType> ref(output_bytes_num / sizeof(DataType));
HIP_CHECK_ERROR(
hipMemcpy(&ref.data()[0], reference.get().output, output_bytes_num, hipMemcpyDeviceToHost));
const ck_tile::index_t GemmK = std::accumulate(conv_param.filter_spatial_lengths_.cbegin(),
conv_param.filter_spatial_lengths_.cend(),
1,
std::multiplies<ck_tile::index_t>()) *
conv_param.C_;
float max_accumulated_value = *std::max_element(ref.begin(), ref.end());
const auto rtol = ck_tile::get_relative_threshold<DataType, DataType, float>(GemmK);
const auto atol =
ck_tile::get_absolute_threshold<DataType, DataType, float>(max_accumulated_value, GemmK);
#endif
[[maybe_unused]] auto run_alg = [&](auto&& run_alg_func) {
std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf);
if(is_supported)
{
best_avg_time = std::min(best_avg_time, avg_time);
best_op_name = best_avg_time < avg_time ? best_op_name : op_name;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms," << " " << op_name
<< std::endl;
#if ENABLE_BUILDER_VALIDATE
const auto errors = ckt::validate(args, outputs, reference.get()).get_errors();
for(const auto& error : errors)
{
valid = false;
std::cout << "Number of incorrect values: " << error.wrong_elements
<< " Is all zero:" << error.is_all_zero() << std::endl;
<< " Is all zero:" << error.is_all_zero()
<< " max err: " << error.max_error << std::endl;
}
best_avg_time = std::min(best_avg_time, avg_time);
best_op_name = best_avg_time < avg_time ? best_op_name : op_name;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms,";
#else
HIP_CHECK_ERROR(
hipMemcpy(&out.data()[0], outputs.output, output_bytes_num, hipMemcpyDeviceToHost));
valid = ck_tile::check_err(out, ref, "Error: Incorrect results!", rtol, atol);
#endif
std::cout << "Relative error threshold: " << rtol
<< " Absolute error threshold: " << atol << std::endl;
}
else
{
std::cout << " " << op_name << std::endl;
}
std::cout << " " << op_name << std::endl;
};
if constexpr(SIGNATURE == SIGNATURE_NHWGC_FP16_FWD)

View File

@@ -6,7 +6,7 @@
#include <tuple>
#include "../../experimental/builder/test/impl/conv_signature_types.hpp"
#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp"
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
namespace ck_tile::builder::profiling {

View File

@@ -17,6 +17,7 @@
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp"
namespace ck {
namespace profiler {
@@ -129,7 +130,10 @@ bool profile_conv_bwd_data_impl(int do_verification,
out_device_buf.ToDevice(output.mData.data());
wei_device_buf.ToDevice(weight.mData.data());
if(do_verification)
// profile device Conv instances
bool pass = true;
if(do_verification == 1)
{
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<NDimSpatial,
InDataType,
@@ -154,6 +158,27 @@ bool profile_conv_bwd_data_impl(int do_verification,
ref_invoker.Run(ref_argument);
}
// GPU reference (compute once, compare in kernel loop)
Tensor<InDataType> gpu_ref_input(in_g_n_c_wis_desc);
if(do_verification == 2)
{
DeviceMem gpu_ref_in_dev(sizeof(InDataType) *
input_device_result.mDesc.GetElementSpaceSize());
gpu_ref_in_dev.SetZero(); // bwd data needs zero initialization
ck::ref::naive_conv_bwd_data<InLayout, WeiLayout, OutLayout>(
static_cast<InDataType*>(gpu_ref_in_dev.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param,
in_element_op,
wei_element_op,
out_element_op);
hip_check_error(hipDeviceSynchronize());
gpu_ref_in_dev.FromDevice(gpu_ref_input.mData.data());
}
using DeviceOp = ck::tensor_operation::device::DeviceConvBwdData<NDimSpatial,
InLayout,
WeiLayout,
@@ -176,8 +201,6 @@ bool profile_conv_bwd_data_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
// profile device Conv instances
bool pass = true;
for(auto& op_ptr : op_ptrs)
{
@@ -235,7 +258,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
if(do_verification == 1)
{
in_device_buf.FromDevice(input_device_result.mData.data());
@@ -255,6 +278,31 @@ bool profile_conv_bwd_data_impl(int do_verification,
show_data_nhwc_layout(input_host_result);
std::cout << std::endl;
std::cout << "out_device: ";
show_data_nhwc_layout(input_device_result);
std::cout << std::endl;
}
}
else if(do_verification == 2)
{
in_device_buf.FromDevice(input_device_result.mData.data());
pass = pass & ck::utils::check_err(input_device_result, gpu_ref_input);
if(do_log)
{
std::cout << "in : ";
show_data_nhwc_layout(output);
std::cout << std::endl;
std::cout << "wei: ";
show_data_nhwc_layout(weight);
std::cout << std::endl;
std::cout << "out_gpu_ref : ";
show_data_nhwc_layout(gpu_ref_input);
std::cout << std::endl;
std::cout << "out_device: ";
show_data_nhwc_layout(input_device_result);
std::cout << std::endl;

View File

@@ -21,6 +21,7 @@
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
namespace ck {
namespace profiler {
@@ -107,8 +108,11 @@ bool profile_conv_fwd_impl(int do_verification,
in_device_buf.ToDevice(input.mData.data());
wei_device_buf.ToDevice(weight.mData.data());
// profile device op instances
bool pass = true;
// run reference op
if(do_verification)
if(do_verification == 1)
{
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
@@ -135,6 +139,24 @@ bool profile_conv_fwd_impl(int do_verification,
ref_invoker.Run(ref_argument);
}
// GPU reference (compute once, compare in kernel loop)
Tensor<OutDataType> gpu_ref_output(out_g_n_k_wos_desc);
if(do_verification == 2)
{
DeviceMem gpu_ref_out_dev(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize());
ck::ref::naive_conv_fwd<InLayout, WeiLayout, OutLayout>(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(gpu_ref_out_dev.GetDeviceBuffer()),
conv_param,
in_element_op,
wei_element_op,
out_element_op);
hip_check_error(hipDeviceSynchronize());
gpu_ref_out_dev.FromDevice(gpu_ref_output.mData.data());
}
using DeviceOp = ck::tensor_operation::device::DeviceConvFwd<NDimSpatial,
InLayout,
@@ -158,8 +180,6 @@ bool profile_conv_fwd_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
// profile device op instances
bool pass = true;
for(auto& op_ptr : op_ptrs)
{
@@ -217,7 +237,7 @@ bool profile_conv_fwd_impl(int do_verification,
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
if(do_verification == 1)
{
out_device_buf.FromDevice(device_output.mData.data());
@@ -233,6 +253,23 @@ bool profile_conv_fwd_impl(int do_verification,
<< std::endl;
}
}
else if(do_verification == 2)
{
out_device_buf.FromDevice(device_output.mData.data());
pass = pass & ck::utils::check_err(device_output, gpu_ref_output);
if(do_log)
{
LogRangeAsType<float>(std::cout << "input : ", input.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "weight: ", weight.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "gpu_ref_output : ", gpu_ref_output.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "device_output: ", device_output.mData, ",")
<< std::endl;
}
}
}
else
{

View File

@@ -364,26 +364,39 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
using AccDataType =
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
// Calculate number of accumulations accounting for split_k
const int num_accums =
static_cast<int>(output.GetElementSize() / conv_param.K_ / split_k_value);
// Additional tolerance for split_k accumulation if needed
int total_accums = num_accums;
if(split_k_value > 1)
{
total_accums = std::max(num_accums, static_cast<int>(split_k_value));
}
// Perform GPU verification (max value computed internally on GPU)
const index_t num_accums = output.GetElementSize() / conv_param.K_;
const index_t num_accums_split_k = split_k_value;
// Get maximum accumulated value from reference
const std::size_t tensor_size =
weight_device_result.mDesc.GetElementSpaceSize();
max_accumulated_value =
gpu_reduce_max<WeiDataType>(gpu_ref_wei_buf.GetDeviceBuffer(), tensor_size);
// Calculate thresholds
auto rtol =
ck::utils::get_relative_threshold<ComputeType, WeiDataType, AccDataType>(
num_accums / num_accums_split_k);
auto atol =
ck::utils::get_absolute_threshold<ComputeType, WeiDataType, AccDataType>(
max_accumulated_value / num_accums_split_k,
num_accums / num_accums_split_k);
// Calculate error due to split_k accumulation
auto rtol_split_k =
ck::utils::get_relative_threshold<WeiDataType, WeiDataType, WeiDataType>(
num_accums_split_k);
auto atol_split_k =
ck::utils::get_absolute_threshold<WeiDataType, WeiDataType, WeiDataType>(
max_accumulated_value, num_accums_split_k);
// Use higher threshold
rtol = std::max(rtol, rtol_split_k);
atol = std::max(atol, atol_split_k);
// Perform GPU verification
auto gpu_result =
ck::profiler::gpu_verify<WeiDataType, ComputeType, AccDataType>(
wei_device_buf.GetDeviceBuffer(),
gpu_ref_wei_buf.GetDeviceBuffer(),
total_accums,
tensor_size);
ck::profiler::gpu_verify<WeiDataType>(wei_device_buf.GetDeviceBuffer(),
gpu_ref_wei_buf.GetDeviceBuffer(),
rtol,
atol,
tensor_size);
if(!gpu_result)
{

View File

@@ -21,6 +21,7 @@
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
namespace ck {
namespace profiler {
@@ -156,8 +157,9 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
bias_device_buf.ToDevice(bias.mData.data());
// run reference op
if(do_verification)
if(do_verification == 1)
{
// CPU reference
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
@@ -190,6 +192,75 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
ref_invoker.Run(ref_argument);
}
else if(do_verification == 2)
{
// GPU reference
std::vector<ck::index_t> d_lengths_vec(NDimSpatial + 3);
std::vector<ck::index_t> d_strides_vec(NDimSpatial + 3);
d_lengths_vec[0] = conv_param.G_;
d_lengths_vec[1] = conv_param.N_;
d_lengths_vec[2] = conv_param.K_;
for(ck::index_t i = 0; i < NDimSpatial; ++i)
{
d_lengths_vec[3 + i] = static_cast<ck::index_t>(conv_param.output_spatial_lengths_[i]);
}
if constexpr(BiasGK)
{
// For GK bias layout: G*K, zero strides for N and spatial dimensions
d_strides_vec[0] = K;
d_strides_vec[1] = 0;
d_strides_vec[2] = 1;
for(ck::index_t i = 0; i < NDimSpatial; ++i)
{
d_strides_vec[3 + i] = 0;
}
}
else
{
// Full GNKHW layout - same as output
ck::ranges::copy(out_g_n_k_wos_desc.GetStrides(), d_strides_vec.begin());
}
std::array<const OutDataType*, 1> d_ptrs = {
reinterpret_cast<const OutDataType*>(bias_device_buf.GetDeviceBuffer())};
std::array<std::vector<ck::index_t>, 1> d_lengths = {d_lengths_vec};
std::array<std::vector<ck::index_t>, 1> d_strides = {d_strides_vec};
std::array<const InDataType*, 1> in_ptrs = {
reinterpret_cast<const InDataType*>(in_device_buf.GetDeviceBuffer())};
std::array<const WeiDataType*, 1> wei_ptrs = {
reinterpret_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer())};
ck::ref::naive_conv_fwd_multi_abd<0,
0,
1,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
OutDataType>( // Explicitly specify TD = OutDataType
in_ptrs,
wei_ptrs,
d_ptrs,
reinterpret_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param,
d_lengths,
d_strides,
in_element_op,
wei_element_op,
out_element_op);
HIP_CHECK_ERROR(hipDeviceSynchronize());
out_device_buf.FromDevice(host_output.mData.data());
}
std::string best_op_name;
float best_avg_time = 0;

View File

@@ -22,6 +22,7 @@
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
namespace ck {
namespace profiler {
@@ -129,8 +130,9 @@ bool profile_grouped_conv_fwd_bilinear_impl(
wei_device_buf.ToDevice(weight.mData.data());
d_device_buf.ToDevice(d_tensor.mData.data());
if(do_verification)
if(do_verification == 1)
{
// CPU reference
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
NDimSpatial,
InDataType,
@@ -167,6 +169,61 @@ bool profile_grouped_conv_fwd_bilinear_impl(
host_output(idx) = ck::type_convert<OutDataType>(out_val);
});
}
else if(do_verification == 2)
{
// GPU reference
std::vector<ck::index_t> d_lengths_vec(NDimSpatial + 3);
std::vector<ck::index_t> d_strides_vec(NDimSpatial + 3);
d_lengths_vec[0] = conv_param.G_;
d_lengths_vec[1] = conv_param.N_;
d_lengths_vec[2] = conv_param.K_;
for(ck::index_t i = 0; i < NDimSpatial; ++i)
{
d_lengths_vec[3 + i] = static_cast<ck::index_t>(conv_param.output_spatial_lengths_[i]);
}
// D tensor has same layout as output
ck::ranges::copy(d_host_tensor_descriptor.GetStrides(), d_strides_vec.begin());
std::array<const DDataType*, 1> d_ptrs = {
reinterpret_cast<const DDataType*>(d_device_buf.GetDeviceBuffer())};
std::array<std::vector<ck::index_t>, 1> d_lengths = {d_lengths_vec};
std::array<std::vector<ck::index_t>, 1> d_strides = {d_strides_vec};
std::array<const InDataType*, 1> in_ptrs = {
reinterpret_cast<const InDataType*>(in_device_buf.GetDeviceBuffer())};
std::array<const WeiDataType*, 1> wei_ptrs = {
reinterpret_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer())};
ck::ref::naive_conv_fwd_multi_abd<0,
0,
1,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DDataType>( // Explicitly specify D tensor type
in_ptrs,
wei_ptrs,
d_ptrs,
reinterpret_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param,
d_lengths,
d_strides,
InElementOp{},
WeiElementOp{},
bilinear_op);
HIP_CHECK_ERROR(hipDeviceSynchronize());
out_device_buf.FromDevice(host_output.mData.data());
}
std::string best_op_name;
float best_avg_time = 0;

View File

@@ -7,6 +7,7 @@
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "profiler/common.hpp"
@@ -150,7 +151,7 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification,
std::cout << "scale_out: " << scale_out << std::endl;
// run reference op
if(do_verification)
if(do_verification == 1)
{
std::cout << "\nVerifying algorithm against reference convolution..." << std::endl;
@@ -200,6 +201,57 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification,
}
});
}
else if(do_verification == 2)
{
// GPU reference
// WORKAROUND: For int8_t with Scale, use CPU post-processing to match CPU reference
// Pure GPU approach fails int8 test (see 2026-01-07-int8-scale-debugging.md)
if constexpr(std::is_same_v<OutElementOp, ck::tensor_operation::element_wise::Scale> &&
std::is_same_v<OutDataType, int8_t>)
{
// Compute conv to CShuffleDataType (float), then post-process on CPU
DeviceMem gpu_ref_c_dev(sizeof(CShuffleDataType) * c.mDesc.GetElementSpaceSize());
ck::ref::naive_conv_fwd<InLayout, WeiLayout, OutLayout>(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<CShuffleDataType*>(gpu_ref_c_dev.GetDeviceBuffer()),
conv_param,
in_element_op,
wei_element_op,
PassThrough{});
ck::hip_check_error(hipDeviceSynchronize());
Tensor<CShuffleDataType> gpu_c(out_g_n_k_wos_desc);
gpu_ref_c_dev.FromDevice(gpu_c.mData.data());
// Post-process on CPU to match CPU reference behavior
host_output.ForEach([&](auto&, auto idx) {
const auto conv_shuffle = ck::type_convert<CShuffleDataType>(gpu_c(idx));
const auto conv_val = ck::type_convert<OutDataType>(conv_shuffle);
out_element_op(host_output(idx), conv_val);
});
}
else
{
// Normal path for non-int8 or non-Scale cases
DeviceMem gpu_ref_out_dev(sizeof(OutDataType) *
device_output.mDesc.GetElementSpaceSize());
ck::ref::naive_conv_fwd<InLayout, WeiLayout, OutLayout>(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(gpu_ref_out_dev.GetDeviceBuffer()),
conv_param,
in_element_op,
wei_element_op,
out_element_op);
ck::hip_check_error(hipDeviceSynchronize());
gpu_ref_out_dev.FromDevice(host_output.mData.data());
}
}
std::string best_op_name;
float best_avg_time = 0;
@@ -239,7 +291,7 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification,
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
if(do_verification == 1)
{
out_device_buf.FromDevice(device_output.mData.data());
@@ -259,6 +311,27 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification,
<< std::endl;
}
}
else if(do_verification == 2)
{
out_device_buf.FromDevice(device_output.mData.data());
pass =
pass & ck::utils::check_err(device_output,
host_output,
"Error: Device and GPU ref results do not match!",
get_rtol<OutDataType>(),
get_atol<OutDataType>());
if(do_log)
{
LogRangeAsType<float>(std::cout << "input : ", input.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "weight: ", weight.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "gpu_ref_output : ", host_output.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "device_output: ", device_output.mData, ",")
<< std::endl;
}
}
}
else
{

View File

@@ -6,7 +6,7 @@
#include <initializer_list>
#include <cstdlib>
#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp"
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "profiler/grouped_convolution_forward_tile_algs.hpp"

View File

@@ -0,0 +1,263 @@
# Build Trace Analysis
Simple-to-use, fast Python tools for analyzing Clang `-ftime-trace` build performance data.
## Overview
We're kicking off a systematic effort to dramatically reduce CK and CK-Tile build times (see [#3575](https://github.com/ROCm/composable_kernel/issues/3575)). A key part of this work is improving our C++ metaprogramming to reduce the burden on the compiler.
In order to prioritize work and measure our progress, we need data on template instantiation. For single files, Clang's `-ftime-trace` build performance data is easy to analyze with the Perfetto UI. The problem we are solving here is how to analyze instantiation data across thousands of compilation units.
The Python code in this directory provides helper functions to quickly load `-ftime-trace` JSON files into pandas DataFrames that can be used for analysis in Jupyter notebooks.
## Directory Structure
```
script/analyze_build/
├── trace_analysis/ # Core library
│ ├── __init__.py # Main exports
│ ├── parse_file.py # Fast parsing of JSON trace files
│ ├── template_analysis.py # Template instantiation analysis
│ ├── template_parser.py # Template name parsing utilities
│ └── phase_breakdown.py # Compilation phase breakdown
├── notebooks/ # Jupyter notebooks for analysis
│ └── file_analysis_example.ipynb # Template analysis example
├── requirements.txt # Python dependencies
└── README.md # This file
```
## Python Requirements
See `requirements.txt` for the complete list of dependencies:
* **pandas** - DataFrame manipulation and analysis
* **orjson** - Fast JSON parsing for trace files
* **plotly** - Interactive visualizations (sunburst, treemap)
* **nbformat** - Jupyter notebook format support
* **ipykernel** - Kernel for running notebooks in VSCode/Jupyter
* **kaleido** - Static image export from Plotly charts
* **jupyter** - Full Jupyter environment
## Quick Start
### Setup
1. Create a virtual environment (recommended):
```bash
cd script/analyze_build
python3 -m venv .venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
```
2. Install dependencies:
```bash
pip install -r requirements.txt
```
3. Install VSCode extensions if you want to run notebooks in VSCode:
* Jupyter
* Data Wrangler (interact with Pandas DataFrames)
### Analyzing a Single File
Use the `parse_file` function to load a `-ftime-trace` JSON file into a Pandas DataFrame:
```python
from trace_analysis import parse_file
# Parse the trace file
df = parse_file('path/to/trace.json')
# View basic info
print(f"Total events: {len(df)}")
print(df.columns)
# Analyze duration statistics
print(df['dur'].describe())
```
### Extracting Compilation Metadata
Get high-level metadata about the compilation:
```python
from trace_analysis import get_metadata
# Extract metadata from trace file
metadata = get_metadata('trace.json')
print(f"Source file: {metadata['source_file']}")
print(f"Compilation time: {metadata['total_wall_time_s']:.2f}s")
print(f"Started: {metadata['wall_start_datetime']}")
print(f"Ended: {metadata['wall_end_datetime']}")
```
The metadata includes:
- `source_file`: Main .cpp/.c file being compiled
- `time_granularity`: Time unit used ("microseconds")
- `beginning_of_time`: Epoch timestamp in microseconds
- `wall_start_time`: Wall clock start (microseconds since epoch)
- `wall_end_time`: Wall clock end (microseconds since epoch)
- `wall_start_datetime`: Human-readable start time
- `wall_end_datetime`: Human-readable end time
- `total_wall_time_us`: Total compilation time in microseconds
- `total_wall_time_s`: Total compilation time in seconds
### Template Instantiation Analysis
The module includes specialized functions for analyzing C++ template instantiation costs:
```python
from trace_analysis import (
parse_file,
get_template_instantiation_events,
get_phase_breakdown,
)
df = parse_file('trace.json')
# Get all template instantiation events with parsed template information
template_events = get_template_instantiation_events(df)
# The returned DataFrame includes parsed columns:
# - namespace: Top-level namespace (e.g., 'std', 'ck')
# - template_name: Template name without parameters
# - full_qualified_name: Full namespace::template_name
# - param_count: Number of template parameters
# - is_ck_type: Boolean indicating CK library types
# - is_nested: Boolean indicating nested templates
# Find slowest template instantiations
top_templates = template_events.nlargest(20, 'dur')
print(top_templates[['template_name', 'namespace', 'param_count', 'dur']])
# Analyze by namespace
namespace_summary = template_events.groupby('namespace').agg({
'dur': ['count', 'sum', 'mean']
})
print(namespace_summary)
```
### Compilation Phase Breakdown
Analyze how compilation time is distributed across different phases:
```python
from trace_analysis import parse_file, get_phase_breakdown, PhaseBreakdown
df = parse_file('trace.json')
# Get hierarchical phase breakdown
breakdown = get_phase_breakdown(df)
# Display in Jupyter (automatic rich HTML display)
display(breakdown)
# Print text representation
print(breakdown)
# Access the underlying DataFrame
print(breakdown.df)
# Convert to plotly format for visualization
import plotly.express as px
data = breakdown.to_plotly()
fig = px.sunburst(**data)
fig.show()
```
The `PhaseBreakdown` class provides:
- Hierarchical breakdown of compilation phases
- Automatic calculation of "Other" residual time at each level
- Validation that children don't exceed parent durations
- Multiple output formats (text, DataFrame, Plotly)
## DataFrame Schema
The parsed DataFrame contains the following columns from the `-ftime-trace` format:
- `name`: Event name (function, template instantiation, etc.)
- `ph`: Phase character ('X' for complete, 'B' for begin, 'E' for end, 'i' for instant)
- `ts`: Timestamp in microseconds
- `dur`: Duration in microseconds (for complete events)
- `pid`: Process ID
- `tid`: Thread ID
- `arg_*`: Flattened arguments from the event's `args` field
### Template Event Columns
When using `get_template_instantiation_events()`, additional parsed columns are included:
- `namespace`: Top-level namespace extracted from the template name
- `template_name`: Template name without namespace or parameters
- `full_qualified_name`: Complete namespace::template_name
- `param_count`: Number of template parameters
- `is_ck_type`: Boolean flag for CK library types (namespace starts with 'ck')
- `is_nested`: Boolean flag indicating nested template instantiations
## Use in Jupyter Notebooks
The module is designed to work seamlessly in Jupyter notebooks. See `notebooks/file_analysis_example.ipynb` for a complete example workflow that demonstrates:
- Loading and parsing trace files
- Extracting compilation metadata
- Analyzing phase breakdown with visualizations
- Template instantiation analysis with parsed columns
- Filtering and grouping by namespace
- Identifying CK-specific template costs
To use in a notebook:
```python
import sys
from pathlib import Path
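# Add the parent of the current working directory to the import path so that
# the trace_analysis package can be imported (assumes the notebook runs from notebooks/)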
sys.path.insert(0, str(Path.cwd().parent))
from trace_analysis import (
parse_file,
get_metadata,
get_template_instantiation_events,
get_phase_breakdown,
)
# Load and analyze
df = parse_file('path/to/trace.json')
breakdown = get_phase_breakdown(df)
templates = get_template_instantiation_events(df)
# Visualize
import plotly.express as px
fig = px.sunburst(**breakdown.to_plotly())
fig.show()
```
## API Reference
### Core Functions
- `parse_file(filepath)`: Parse a `-ftime-trace` JSON file into a pandas DataFrame
- `get_metadata(filepath_or_df)`: Extract compilation metadata from trace file or DataFrame
### Template Analysis
- `get_template_instantiation_events(df)`: Filter to template instantiation events with parsed template information
### Phase Breakdown
- `get_phase_breakdown(df)`: Generate hierarchical compilation phase breakdown
- `PhaseBreakdown`: Class representing phase breakdown with multiple output formats
## Contributing
This is an experimental project for analyzing and improving C++ metaprogramming build times. Contributions are welcome! When adding new analysis functions:
1. Add the function to the appropriate module in `trace_analysis/`
2. Export it in `__init__.py`
3. Update this README with usage examples
4. Consider adding a notebook example if the feature is substantial
## License
Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
SPDX-License-Identifier: MIT

Some files were not shown because too many files have changed in this diff Show More