mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Merge remote-tracking branch 'origin/jograner/bwd-weight-splitk-autodeduce' into features/grouped-conv-perf-uplift
This commit is contained in:
@@ -25,6 +25,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
## Composable Kernel 1.2.0 for ROCm 7.2.0
|
||||
|
||||
### Added
|
||||
* Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4
|
||||
* Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with with C++ and Python frontends starting with GEMM support.
|
||||
* Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle.
|
||||
* Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM.
|
||||
|
||||
@@ -259,6 +259,11 @@ if ((SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx
|
||||
add_definitions(-DCK_USE_GFX94)
|
||||
set(CK_USE_GFX94 "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx950" AND NOT FORCE_DISABLE_XDL)
|
||||
message(STATUS "Enabling XDL FP8 gemms on gfx950")
|
||||
add_definitions(-DCK_USE_GFX950)
|
||||
set(CK_USE_GFX950 "ON")
|
||||
endif()
|
||||
|
||||
# new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA
|
||||
set(CK_TILE_USE_WMMA 0)
|
||||
|
||||
101
Dockerfile.manylinux
Normal file
101
Dockerfile.manylinux
Normal file
@@ -0,0 +1,101 @@
|
||||
FROM ghcr.io/rocm/therock_build_manylinux_x86_64:latest
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ARG ROCMVERSION=7.2
|
||||
ARG compiler_version=""
|
||||
ARG compiler_commit=""
|
||||
ARG CK_SCCACHE=""
|
||||
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
|
||||
ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
USER root
|
||||
|
||||
# Add rocm repository
|
||||
RUN dnf clean all && dnf update -y && dnf -v install wget gnupg2 curl -y
|
||||
|
||||
RUN wget https://repo.radeon.com/amdgpu-install/7.2/rhel/8.10/amdgpu-install-7.2.70200-1.el8.noarch.rpm && \
|
||||
dnf install ./amdgpu-install-7.2.70200-1.el8.noarch.rpm -y && \
|
||||
dnf update -y && \
|
||||
dnf install python3-setuptools python3-wheel -y && \
|
||||
dnf install rocm-dev -y
|
||||
|
||||
## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined
|
||||
ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache
|
||||
ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin
|
||||
ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION}
|
||||
ENV CK_SCCACHE=$CK_SCCACHE
|
||||
RUN if [ "$CK_SCCACHE" != "" ]; then \
|
||||
mkdir -p ${SCCACHE_INSTALL_LOCATION} && \
|
||||
curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \
|
||||
chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache; \
|
||||
fi
|
||||
|
||||
# Install dependencies
|
||||
RUN dnf update -y && DEBIAN_FRONTEND=noninteractive dnf install -y \
|
||||
cmake \
|
||||
clang-tools-extra \
|
||||
gcc-c++ \
|
||||
libstdc++ \
|
||||
libstdc++-devel \
|
||||
libstdc++-static \
|
||||
git \
|
||||
hip-rocclr \
|
||||
jq \
|
||||
mpich \
|
||||
net-tools \
|
||||
pkg-config \
|
||||
redis \
|
||||
sshpass \
|
||||
stunnel \
|
||||
vim \
|
||||
nano \
|
||||
zip \
|
||||
openssh-server \
|
||||
kmod && \
|
||||
dnf clean all && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm -rf amdgpu-install* && \
|
||||
#Install latest ccache
|
||||
git clone https://github.com/ccache/ccache.git && \
|
||||
cd ccache && mkdir build && cd build && cmake .. && make install && \
|
||||
#Install ClangBuildAnalyzer
|
||||
git clone https://github.com/aras-p/ClangBuildAnalyzer.git && \
|
||||
cd ClangBuildAnalyzer/ && \
|
||||
make -f projects/make/Makefile && \
|
||||
cd / && \
|
||||
#Install latest cppcheck
|
||||
git clone https://github.com/danmar/cppcheck.git && \
|
||||
cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . && \
|
||||
cd / && \
|
||||
# Install packages for processing the performance results
|
||||
pip3 install --break-system-packages --upgrade pytest pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \
|
||||
# Add render group
|
||||
groupadd -f render && \
|
||||
# Install the new rocm-cmake version
|
||||
git clone -b master https://github.com/ROCm/rocm-cmake.git && \
|
||||
cd rocm-cmake && mkdir build && cd build && \
|
||||
cmake .. && cmake --build . && cmake --build . --target install
|
||||
|
||||
WORKDIR /
|
||||
# Add alternative compilers, if necessary
|
||||
ENV compiler_version=$compiler_version
|
||||
ENV compiler_commit=$compiler_commit
|
||||
RUN sh -c "echo compiler version = '$compiler_version'" && \
|
||||
sh -c "echo compiler commit = '$compiler_commit'"
|
||||
|
||||
RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \
|
||||
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
|
||||
cd llvm-project && mkdir build && cd build && \
|
||||
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
|
||||
make -j 8 ; \
|
||||
else echo "using the release compiler"; \
|
||||
fi
|
||||
|
||||
RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \
|
||||
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
|
||||
cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
|
||||
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
|
||||
make -j 8 ; \
|
||||
else echo "using the release compiler"; \
|
||||
fi
|
||||
|
||||
4
Jenkinsfile
vendored
4
Jenkinsfile
vendored
@@ -39,10 +39,10 @@ def sendFailureNotifications() {
|
||||
// Error patterns to scan build logs for specific failure types and send detailed notifications.
|
||||
def failurePatterns = [
|
||||
[pattern: /login attempt to .* failed with status: 401 Unauthorized/, description: "Docker registry authentication failed"],
|
||||
[pattern: /docker login failed/, description: "Docker login failed"],
|
||||
[pattern: /(.*)docker login failed(.*)/, description: "Docker login failed"],
|
||||
[pattern: /HTTP request sent .* 404 Not Found/, description: "HTTP request failed with 404"],
|
||||
[pattern: /cat: .* No such file or directory/, description: "GPU not found"],
|
||||
[pattern: /GPU not found/, description: "GPU not found"],
|
||||
[pattern: /(.*)GPU not found(.*)/, description: "GPU not found"],
|
||||
[pattern: /Could not connect to Redis at .* Connection timed out/, description: "Redis connection timed out"]
|
||||
]
|
||||
|
||||
|
||||
@@ -19,22 +19,22 @@ using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
PassThrough, PassThrough, PassThrough, GemmSpec,
|
||||
256,
|
||||
128, 256, 64,
|
||||
8, 8,
|
||||
16, 16,
|
||||
2, 8,
|
||||
S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 8, 8, 1,
|
||||
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, 1,
|
||||
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, 1,
|
||||
1, 8, 8, 1,
|
||||
1, 1,
|
||||
S<1, 64, 1, 4>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
using GemmConfig = GemmConfigQuantDecodeInterwave<T>;
|
||||
|
||||
// GemmConfigQuantPrefill is also supported for aquant grouped quantization
|
||||
// template <typename T>
|
||||
|
||||
@@ -93,6 +93,27 @@ struct GemmConfigQuantDecode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
// static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigQuantDecodeInterwave : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -229,6 +250,8 @@ struct GemmConfigQuantPrefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
// static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
|
||||
@@ -650,7 +650,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
else
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(*aq_tensor_ptr);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(1.0f)}(*aq_tensor_ptr);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
@@ -659,6 +659,184 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(init_method == 3)
|
||||
{
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(*aq_tensor_ptr);
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(2.0f)}(*aq_tensor_ptr);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
|
||||
|
||||
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
|
||||
{
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(init_method == 4)
|
||||
{
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
else if(init_method == 5)
|
||||
{
|
||||
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.0f, 1.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
// Fill aquant such that column j has value 2^j (1, 2, 4, 8, ...)
|
||||
for(ck_tile::index_t row = 0;
|
||||
row < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(0));
|
||||
++row)
|
||||
{
|
||||
for(ck_tile::index_t col = 0;
|
||||
col < static_cast<ck_tile::index_t>(aq_tensor_ptr->get_length(1));
|
||||
++col)
|
||||
{
|
||||
(*aq_tensor_ptr)(row, col) = static_cast<AQDataType>(col + 1);
|
||||
}
|
||||
}
|
||||
// std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl;
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.0f, 1.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*aq_tensor_ptr);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
|
||||
@@ -58,6 +58,7 @@ consteval BlockGemmSpec SetBlockGemm()
|
||||
case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break;
|
||||
case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break;
|
||||
case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break;
|
||||
case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile.";
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM.";
|
||||
default: throw "Unknown PipelineVersion";
|
||||
@@ -92,6 +93,7 @@ consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
|
||||
case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K.";
|
||||
case PipelineVersion::V4: return ck_pipeline::v4;
|
||||
case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM.";
|
||||
case PipelineVersion::V6: throw "PipelineVersion::V6 can be used only for CK TILE.";
|
||||
case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only;
|
||||
default: throw "Unknown GridwiseGemmPipelineVersion";
|
||||
}
|
||||
@@ -137,6 +139,7 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
case PipelineVersion::V3: return ck_pipeline::v3;
|
||||
case PipelineVersion::V4: return ck_pipeline::v4;
|
||||
case PipelineVersion::V5: return ck_pipeline::v5;
|
||||
case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile.";
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
|
||||
default: throw "Unknown block GEMM PipelineVersion";
|
||||
|
||||
@@ -91,6 +91,13 @@ struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V5>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V6>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
|
||||
{
|
||||
@@ -103,6 +110,7 @@ consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
|
||||
case PipelineVersion::V3: return ck_tile_pipeline::COMPUTE_V3;
|
||||
case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4;
|
||||
case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5;
|
||||
case PipelineVersion::V6: return ck_tile_pipeline::COMPUTE_V6;
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
|
||||
default: throw "Unknown block GEMM PipelineVersion";
|
||||
|
||||
@@ -7,26 +7,25 @@
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
#include "ck_tile/builder/testing/testing_reflect.hpp"
|
||||
#include "ck_tile/builder/testing/filter_extent.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_buffer.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_initialization.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
|
||||
#include "ck_tile/builder/testing/validation.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
/// This file implements common functionality for invoking/testing grouped
|
||||
/// forward convolutions created through the CK Builder API. The main item
|
||||
/// of it is the ConvArgs structure - which contains a complete description
|
||||
/// of it is the Args structure - which contains a complete description
|
||||
/// of a convolution operation.
|
||||
///
|
||||
/// It is not intended that this file contains implementation details for
|
||||
/// actually launching a convolution operation. As this can be done
|
||||
/// through different APIs depending on the kernel (CK, CK Tile, or a
|
||||
/// reference implementation), the code dealing with that is split out
|
||||
/// into a separate header for each implementation.
|
||||
/// into a separate header for each implementation. Nor does this file
|
||||
/// deal with details for defining the data types (`Inputs` and `Outputs`)
|
||||
/// for different conv directions, that is also split out into separate
|
||||
/// headers to keep this one small.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
@@ -56,7 +55,7 @@ struct ConvTensorLengths
|
||||
///
|
||||
/// @see Args
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE>
|
||||
struct Args<SIGNATURE>
|
||||
{
|
||||
constexpr static auto SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
@@ -204,53 +203,4 @@ struct Args<SIGNATURE>
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `Inputs` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see Inputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
struct Inputs<SIGNATURE>
|
||||
{
|
||||
void* input;
|
||||
void* weight;
|
||||
|
||||
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
inspect("input", args.make_input_descriptor(), &Inputs<SIGNATURE>::input);
|
||||
inspect("weight", args.make_weight_descriptor(), &Inputs<SIGNATURE>::weight);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `Outputs` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see Outputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
struct Outputs<SIGNATURE>
|
||||
{
|
||||
void* output;
|
||||
|
||||
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
inspect("output", args.make_output_descriptor(), &Outputs<SIGNATURE>::output);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `init_inputs()` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see alloc_inputs()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
|
||||
{
|
||||
init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f);
|
||||
init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -0,0 +1,71 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/tensor_initialization.hpp"
|
||||
#include "ck_tile/builder/testing/testing_reflect.hpp"
|
||||
#include "ck_tile/builder/testing/conv/args.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd.hpp"
|
||||
#include "ck_tile/builder/testing/error.hpp"
|
||||
|
||||
/// This file deals with the backward weight-specific details of running grouped
|
||||
/// convolution backwards weight operations. It mainly defines the data
|
||||
/// structures (`Input` and `Output`), initialization, and validation. Note
|
||||
/// that for this operation specifically, many of the operations are
|
||||
/// implemented automatically via testing_reflect.hpp.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
/// @brief `Inputs` specialization for backwards weight convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Backwards weight convolution signature.
|
||||
///
|
||||
/// @see Inputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct Inputs<SIGNATURE>
|
||||
{
|
||||
void* input;
|
||||
void* output;
|
||||
|
||||
// See testing_reflect.hpp
|
||||
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
inspect("input", args.make_input_descriptor(), &Inputs<SIGNATURE>::input);
|
||||
inspect("output", args.make_output_descriptor(), &Inputs<SIGNATURE>::output);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `Outputs` specialization for backwards weight convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Backwards weight convolution signature.
|
||||
///
|
||||
/// @see Outputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct Outputs<SIGNATURE>
|
||||
{
|
||||
void* weight;
|
||||
|
||||
// See testing_reflect.hpp
|
||||
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
inspect("weight", args.make_weight_descriptor(), &Outputs<SIGNATURE>::weight);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `init_inputs()` specialization for backwards convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Backwards weight convolution signature.
|
||||
///
|
||||
/// @see init_inputs()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
|
||||
{
|
||||
init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f);
|
||||
init_tensor_buffer_uniform_fp(inputs.output, args.make_output_descriptor(), -2.0f, 2.0f);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -0,0 +1,276 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include <type_traits>
|
||||
#include <array>
|
||||
|
||||
/// This file contains the implementation details for invoking/testing
|
||||
/// bwd grouped convolution operations in old CK. The main item is the
|
||||
/// `run()` function, which is the main implementation used to invoke
|
||||
/// CK grouped forward convolution kernels.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// @brief Concept for checking whether a bwd weight convolution is invoked like old CK.
|
||||
///
|
||||
/// This is the same as `::ck_tile::builder::test::CkConvBwdWeightInstance`, except
|
||||
/// with some utility aliases. For that reason, its moved to this detail
|
||||
/// namespace.
|
||||
template <typename Conv,
|
||||
auto SIGNATURE,
|
||||
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
|
||||
// TODO: We shouldn't need to call into an internal namespace here.
|
||||
typename Types = factory::internal::ConvTensorDataTypes<SIGNATURE>,
|
||||
typename Ops = factory::internal::ConvElementwiseOps<SIGNATURE>>
|
||||
concept CkConvBwdWeightInstance = requires(Conv& conv,
|
||||
const Types::InDataType* p_a,
|
||||
Types::WeiDataType* p_b,
|
||||
const Types::OutDataType* p_e,
|
||||
std::array<index_t, SPATIAL_DIM + 3> lengths,
|
||||
std::array<index_t, SPATIAL_DIM + 3> strides,
|
||||
std::array<index_t, SPATIAL_DIM> filter,
|
||||
Ops::InElementwiseOp elementwise_a,
|
||||
Ops::WeiElementwiseOp elementwise_b,
|
||||
Ops::OutElementwiseOp elementwise_cde,
|
||||
ck::index_t split_k) {
|
||||
requires ValidConvSignature<SIGNATURE>;
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>;
|
||||
|
||||
{
|
||||
conv.MakeArgument(p_a,
|
||||
p_b,
|
||||
p_e,
|
||||
// A lengths/strides
|
||||
lengths,
|
||||
strides,
|
||||
// B lengths/strides
|
||||
lengths,
|
||||
strides,
|
||||
// E lengths/strides
|
||||
lengths,
|
||||
strides,
|
||||
// strides/dilations/pads
|
||||
filter,
|
||||
filter,
|
||||
filter,
|
||||
filter,
|
||||
// element-wise operations.
|
||||
elementwise_a,
|
||||
elementwise_b,
|
||||
elementwise_cde,
|
||||
split_k)
|
||||
};
|
||||
};
|
||||
|
||||
/// @brief Concept for checking whether a bwd weight convolution is multiple-d and
|
||||
/// invoked like old CK.
|
||||
///
|
||||
/// This is the same as `::ck_tile::builder::test::CkConvBwdWeightMultipleDInstance`, except
|
||||
/// with some utility aliases. For that reason, its moved to this detail
|
||||
/// namespace.
|
||||
template <typename Conv,
|
||||
auto SIGNATURE,
|
||||
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
|
||||
// TODO: We shouldn't need to call into an internal namespace here.
|
||||
typename Types = factory::internal::ConvTensorDataTypes<SIGNATURE>,
|
||||
typename Ops = factory::internal::ConvElementwiseOps<SIGNATURE>>
|
||||
concept CkConvBwdWeightMultipleDInstance = requires(Conv& conv,
|
||||
const Types::InDataType* p_a,
|
||||
Types::WeiDataType* p_b,
|
||||
const Types::OutDataType* p_e,
|
||||
std::array<index_t, SPATIAL_DIM + 3> lengths,
|
||||
std::array<index_t, SPATIAL_DIM + 3> strides,
|
||||
std::array<index_t, SPATIAL_DIM> filter,
|
||||
Ops::InElementwiseOp elementwise_a,
|
||||
Ops::WeiElementwiseOp elementwise_b,
|
||||
Ops::OutElementwiseOp elementwise_cde,
|
||||
ck::index_t split_k) {
|
||||
requires ValidConvSignature<SIGNATURE>;
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>;
|
||||
|
||||
{
|
||||
conv.MakeArgument(p_a,
|
||||
p_b,
|
||||
p_e,
|
||||
// TODO: Actually support multiple d
|
||||
{},
|
||||
// A lengths/strides
|
||||
lengths,
|
||||
strides,
|
||||
// B lengths/strides
|
||||
lengths,
|
||||
strides,
|
||||
// E lengths/strides
|
||||
lengths,
|
||||
strides,
|
||||
// TODO: Multiple D lengths/strides
|
||||
{},
|
||||
{},
|
||||
// strides/dilations/pads
|
||||
filter,
|
||||
filter,
|
||||
filter,
|
||||
filter,
|
||||
// element-wise operations.
|
||||
elementwise_a,
|
||||
elementwise_b,
|
||||
elementwise_cde,
|
||||
split_k)
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// @brief Concept for checking whether a bwd weight convolution is invoked like old CK.
|
||||
///
|
||||
/// - SIGNATURE is the operation signature.
|
||||
/// - Conv is a convolution instance created by the CK Builder API.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept CkConvBwdWeightInstance = detail::CkConvBwdWeightInstance<Conv, SIGNATURE>;
|
||||
|
||||
/// @brief Concept for checking whether a bwd weight convolution is multiple-d and
|
||||
/// invoked like old CK.
|
||||
///
|
||||
/// - SIGNATURE is the operation signature.
|
||||
/// - Conv is a convolution instance created by the CK Builder API.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept CkConvBwdWeightMultipleDInstance =
|
||||
detail::CkConvBwdWeightMultipleDInstance<Conv, SIGNATURE>;
|
||||
|
||||
/// @brief `run()` specialization for backward weight convolution and old CK.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
[[nodiscard]] RunResult run(CkConvBwdWeightInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs)
|
||||
{
|
||||
using Types = factory::internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
|
||||
constexpr auto spatial_dim = SIGNATURE.spatial_dim;
|
||||
|
||||
const auto copy = [](const auto& src, auto& dst) {
|
||||
std::copy(src.begin(), src.end(), dst.begin());
|
||||
};
|
||||
|
||||
const auto to_ck_lengths = [&](const auto& src) {
|
||||
std::array<ck::index_t, spatial_dim + 3> result;
|
||||
copy(src, result);
|
||||
return result;
|
||||
};
|
||||
|
||||
const auto to_ck_extent = [&](const auto& extent) {
|
||||
std::array<ck::index_t, spatial_dim> result;
|
||||
copy(extent, result);
|
||||
return result;
|
||||
};
|
||||
|
||||
const auto param = args.to_ck_conv_param();
|
||||
|
||||
const auto input_desc = args.make_input_descriptor();
|
||||
const auto weight_desc = args.make_weight_descriptor();
|
||||
const auto output_desc = args.make_output_descriptor();
|
||||
|
||||
auto ck_args = conv.MakeArgument(static_cast<const Types::InDataType*>(inputs.input),
|
||||
static_cast<Types::WeiDataType*>(outputs.weight),
|
||||
static_cast<const Types::OutDataType*>(inputs.output),
|
||||
to_ck_lengths(input_desc.get_lengths()),
|
||||
to_ck_lengths(input_desc.get_strides()),
|
||||
to_ck_lengths(weight_desc.get_lengths()),
|
||||
to_ck_lengths(weight_desc.get_strides()),
|
||||
to_ck_lengths(output_desc.get_lengths()),
|
||||
to_ck_lengths(output_desc.get_strides()),
|
||||
to_ck_extent(param.conv_filter_strides_),
|
||||
to_ck_extent(param.conv_filter_dilations_),
|
||||
to_ck_extent(param.input_left_pads_),
|
||||
to_ck_extent(param.input_right_pads_),
|
||||
args.a_elementwise_op,
|
||||
args.b_elementwise_op,
|
||||
args.cde_elementwise_op,
|
||||
args.k_batch);
|
||||
|
||||
if(!conv.IsSupportedArgument(ck_args))
|
||||
return RunResult::not_supported("invalid ck arguments");
|
||||
|
||||
return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, {}));
|
||||
}
|
||||
|
||||
/// @brief `run()` specialization for backward weight convolution and old CK.
|
||||
///
|
||||
/// This overload is specialized for Multiple-D.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
[[nodiscard]] RunResult run(CkConvBwdWeightMultipleDInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs)
|
||||
{
|
||||
using Types = factory::internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
|
||||
constexpr auto spatial_dim = SIGNATURE.spatial_dim;
|
||||
|
||||
const auto copy = [](const auto& src, auto& dst) {
|
||||
std::copy(src.begin(), src.end(), dst.begin());
|
||||
};
|
||||
|
||||
const auto to_ck_lengths = [&](const auto& src) {
|
||||
std::array<ck::index_t, spatial_dim + 3> result;
|
||||
copy(src, result);
|
||||
return result;
|
||||
};
|
||||
|
||||
const auto to_ck_extent = [&](const auto& extent) {
|
||||
std::array<ck::index_t, spatial_dim> result;
|
||||
copy(extent, result);
|
||||
return result;
|
||||
};
|
||||
|
||||
const auto param = args.to_ck_conv_param();
|
||||
|
||||
const auto input_desc = args.make_input_descriptor();
|
||||
const auto weight_desc = args.make_weight_descriptor();
|
||||
const auto output_desc = args.make_output_descriptor();
|
||||
|
||||
auto ck_args = conv.MakeArgument(static_cast<const Types::InDataType*>(inputs.input),
|
||||
static_cast<Types::WeiDataType*>(outputs.weight),
|
||||
static_cast<const Types::OutDataType*>(inputs.output),
|
||||
{}, // TODO
|
||||
to_ck_lengths(input_desc.get_lengths()),
|
||||
to_ck_lengths(input_desc.get_strides()),
|
||||
to_ck_lengths(weight_desc.get_lengths()),
|
||||
to_ck_lengths(weight_desc.get_strides()),
|
||||
to_ck_lengths(output_desc.get_lengths()),
|
||||
to_ck_lengths(output_desc.get_strides()),
|
||||
{}, // TODO
|
||||
{}, // TODO
|
||||
to_ck_extent(param.conv_filter_strides_),
|
||||
to_ck_extent(param.conv_filter_dilations_),
|
||||
to_ck_extent(param.input_left_pads_),
|
||||
to_ck_extent(param.input_right_pads_),
|
||||
args.a_elementwise_op,
|
||||
args.b_elementwise_op,
|
||||
args.cde_elementwise_op,
|
||||
args.k_batch);
|
||||
|
||||
if(!conv.IsSupportedArgument(ck_args))
|
||||
return RunResult::not_supported("invalid ck arguments");
|
||||
|
||||
return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, {}));
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -3,9 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/conv_fwd.hpp"
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
#include <type_traits>
|
||||
@@ -28,9 +27,39 @@ namespace detail {
|
||||
/// namespace.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept CkTileConvInstance = requires(Conv&) {
|
||||
requires ValidConvSignature<SIGNATURE>;
|
||||
{ Conv::BlockSize() };
|
||||
};
|
||||
|
||||
template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
InDataType* input,
|
||||
WeiDataType* weight,
|
||||
OutDataType* output,
|
||||
const ck_tile::stream_config s_conf)
|
||||
{
|
||||
using Conv = std::remove_reference_t<decltype(conv)>;
|
||||
const auto param = args.to_ck_tile_conv_param();
|
||||
|
||||
ck_tile::GroupedConvHostArgs<InDataType*, WeiDataType*, OutDataType*, ck_tile::PassThrough>
|
||||
host_args(param, input, weight, {}, output, args.k_batch);
|
||||
|
||||
auto kargs = Conv::MakeKernelArgs(host_args);
|
||||
|
||||
const dim3 grids = Conv::GridSize(kargs);
|
||||
const dim3 blocks = Conv::BlockSize();
|
||||
|
||||
if(!Conv::IsSupportedArgument(kargs))
|
||||
return RunResult::not_supported("unsupported ck_tile arguments");
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
|
||||
return RunResult::from_runtime(ck_tile::launch_kernel(
|
||||
s_conf, ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs)));
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// @brief Concept for checking whether a convolution is invoked like CK Tile.
|
||||
@@ -48,44 +77,45 @@ concept CkTileConvInstance = detail::CkTileConvInstance<Conv, SIGNATURE>;
|
||||
/// @brief `run()` specialization for forward convolution and CK Tile.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
/// @throws std::runtime_error if the arguments weren't actually valid for the
|
||||
/// operation. This should be caught and reported by the testing framework.
|
||||
/// @return std::tuple<bool, float> - whether the problem is supported and
|
||||
/// kernel execution time (0.0f if s_conf time_kernel is false).
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
std::tuple<bool, float> run(CkTileConvInstance<SIGNATURE> auto& conv,
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs,
|
||||
const ck_tile::stream_config s_conf = {})
|
||||
{
|
||||
using Conv = std::remove_reference_t<decltype(conv)>;
|
||||
const auto param = args.to_ck_tile_conv_param();
|
||||
return detail::run(conv,
|
||||
args,
|
||||
static_cast<const void*>(inputs.input),
|
||||
static_cast<const void*>(inputs.weight),
|
||||
static_cast<void*>(outputs.output),
|
||||
s_conf);
|
||||
}
|
||||
|
||||
ck_tile::GroupedConvFwdHostArgs<> host_args(
|
||||
param, inputs.input, inputs.weight, {}, outputs.output, args.k_batch);
|
||||
|
||||
auto kargs = Conv::MakeKernelArgs(host_args);
|
||||
|
||||
const dim3 grids = Conv::GridSize(kargs);
|
||||
const dim3 blocks = Conv::BlockSize();
|
||||
|
||||
if(!Conv::IsSupportedArgument(kargs))
|
||||
{
|
||||
std::cout << "Not supported!";
|
||||
return std::make_tuple(false, 0.f);
|
||||
}
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
|
||||
return std::make_tuple(
|
||||
true,
|
||||
ck_tile::launch_kernel(
|
||||
s_conf, ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs)));
|
||||
/// @brief `run()` specialization for backwards weight convolution and CK Tile.
|
||||
///
|
||||
/// @tparam SIGNATURE Backwards weight convolution signature.
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs,
|
||||
const ck_tile::stream_config s_conf = {})
|
||||
{
|
||||
return detail::run(conv,
|
||||
args,
|
||||
static_cast<const void*>(inputs.input),
|
||||
static_cast<void*>(outputs.weight),
|
||||
static_cast<const void*>(inputs.output),
|
||||
s_conf);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -0,0 +1,69 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/tensor_initialization.hpp"
|
||||
#include "ck_tile/builder/testing/testing_reflect.hpp"
|
||||
#include "ck_tile/builder/testing/conv/args.hpp"
|
||||
|
||||
/// This file deals with the forward-specific details of running grouped
|
||||
/// convolution forward operations. It mainly defines the data structures
|
||||
/// (`Input` and `Output`), initialization, and validation. Note that
|
||||
/// for this operation specifically, many of the operations are implemented
|
||||
/// automatically via testing_reflect.hpp.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
/// @brief `Inputs` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see Inputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
struct Inputs<SIGNATURE>
|
||||
{
|
||||
void* input;
|
||||
void* weight;
|
||||
|
||||
// See testing_reflect.hpp
|
||||
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
inspect("input", args.make_input_descriptor(), &Inputs<SIGNATURE>::input);
|
||||
inspect("weight", args.make_weight_descriptor(), &Inputs<SIGNATURE>::weight);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `Outputs` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see Outputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
struct Outputs<SIGNATURE>
|
||||
{
|
||||
void* output;
|
||||
|
||||
// See testing_reflect.hpp
|
||||
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
inspect("output", args.make_output_descriptor(), &Outputs<SIGNATURE>::output);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `init_inputs()` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see init_inputs()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
|
||||
{
|
||||
init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f);
|
||||
init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -3,14 +3,14 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/conv_fwd.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include <type_traits>
|
||||
#include <array>
|
||||
|
||||
/// This file contains the implementation details for invoking/testing
|
||||
/// grouped convolution operations in old CK. The main item is the
|
||||
/// fwd grouped convolution operations in old CK. The main item is the
|
||||
/// `run()` function, which is the main implementation used to invoke
|
||||
/// CK grouped forward convolution kernels.
|
||||
|
||||
@@ -18,10 +18,9 @@ namespace ck_tile::builder::test {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// @brief Concept for checking whether this is the reference convolution
|
||||
/// implementation.
|
||||
/// @brief Concept for checking whether a fwd convolution is invoked like old CK.
|
||||
///
|
||||
/// This is the same as `::ck_tile::builder::test::CkConvInstance`, except
|
||||
/// This is the same as `::ck_tile::builder::test::CkConvFwdInstance`, except
|
||||
/// with some utility aliases. For that reason, its moved to this detail
|
||||
/// namespace.
|
||||
template <typename Conv,
|
||||
@@ -29,18 +28,21 @@ template <typename Conv,
|
||||
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
|
||||
// TODO: We shouldn't need to call into an internal namespace here.
|
||||
typename Ops = factory::internal::ConvElementwiseOps<SIGNATURE>>
|
||||
concept CkConvInstance = requires(Conv& conv,
|
||||
// TODO: This should be changed depending on IsMultiA etc.
|
||||
// Currently that is not yet supported elsewhere anyway.
|
||||
const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_e,
|
||||
std::array<index_t, SPATIAL_DIM + 3> lengths,
|
||||
std::array<index_t, SPATIAL_DIM + 3> strides,
|
||||
std::array<index_t, SPATIAL_DIM> filter,
|
||||
Ops::InElementwiseOp elementwise_a,
|
||||
Ops::WeiElementwiseOp elementwise_b,
|
||||
Ops::OutElementwiseOp elementwise_cde) {
|
||||
concept CkConvFwdInstance = requires(Conv& conv,
|
||||
// TODO: This should be changed depending on IsMultiA etc.
|
||||
// Currently that is not yet supported elsewhere anyway.
|
||||
const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_e,
|
||||
std::array<index_t, SPATIAL_DIM + 3> lengths,
|
||||
std::array<index_t, SPATIAL_DIM + 3> strides,
|
||||
std::array<index_t, SPATIAL_DIM> filter,
|
||||
Ops::InElementwiseOp elementwise_a,
|
||||
Ops::WeiElementwiseOp elementwise_b,
|
||||
Ops::OutElementwiseOp elementwise_cde) {
|
||||
requires ValidConvSignature<SIGNATURE>;
|
||||
requires ConvDirectionIsForward<SIGNATURE>;
|
||||
|
||||
{
|
||||
conv.MakeArgument(p_a,
|
||||
p_b,
|
||||
@@ -73,7 +75,7 @@ concept CkConvInstance = requires(Conv& conv,
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// @brief Concept for checking whether a convolution is invoked like old CK.
|
||||
/// @brief Concept for checking whether a fwd convolution is invoked like old CK.
|
||||
///
|
||||
/// This concept is used to tell whether a convolution implementation is
|
||||
/// likely to be an "old CK" implementation - that is, whether we should
|
||||
@@ -83,20 +85,17 @@ concept CkConvInstance = requires(Conv& conv,
|
||||
/// - SIGNATURE is the operation signature.
|
||||
/// - Conv is a convolution instance created by the CK Builder API.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept CkConvInstance = detail::CkConvInstance<Conv, SIGNATURE>;
|
||||
concept CkConvFwdInstance = detail::CkConvFwdInstance<Conv, SIGNATURE>;
|
||||
|
||||
/// @brief `run()` specialization for forward convolution and old CK.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
/// @throws std::runtime_error if the arguments weren't actually valid for the
|
||||
/// operation. This should be caught and reported by the testing framework.
|
||||
/// @return std::tuple<bool, float> - whether the problem is supported and
|
||||
/// kernel execution time (0.0f if s_conf time_kernel is false).
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
std::tuple<bool, float> run(CkConvInstance<SIGNATURE> auto& conv,
|
||||
[[nodiscard]] RunResult run(CkConvFwdInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs,
|
||||
@@ -126,6 +125,9 @@ std::tuple<bool, float> run(CkConvInstance<SIGNATURE> auto& conv,
|
||||
const auto weight_desc = args.make_weight_descriptor();
|
||||
const auto output_desc = args.make_output_descriptor();
|
||||
|
||||
if(args.k_batch != 1)
|
||||
return RunResult::not_supported("ck fwd does not support k_batch != 1");
|
||||
|
||||
auto ck_args = conv.MakeArgument(inputs.input,
|
||||
inputs.weight,
|
||||
{},
|
||||
@@ -147,11 +149,9 @@ std::tuple<bool, float> run(CkConvInstance<SIGNATURE> auto& conv,
|
||||
args.cde_elementwise_op);
|
||||
|
||||
if(!conv.IsSupportedArgument(ck_args))
|
||||
{
|
||||
std::cout << "invalid argument" << std::endl;
|
||||
}
|
||||
return RunResult::not_supported("unsupported ck arguments");
|
||||
|
||||
return std::make_tuple(true, conv.MakeInvoker().Run(ck_args, s_conf));
|
||||
return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, s_conf));
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -0,0 +1,137 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
/// This file contains the implementation details for invoking/testing
|
||||
/// grouped convolution operations using the reference implementation.
|
||||
/// The main item is the `run()` function, which is the primary way to
|
||||
/// invoke the reference execution mechanism.
|
||||
/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`,
|
||||
/// but its made specific to the reference implementation, which is
|
||||
/// invoked in a slightly different way.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// @brief Concept for checking whether this is the reference convolution
|
||||
/// implementation.
|
||||
///
|
||||
/// This concept is used to tell whether a convolution implementation is
|
||||
/// likely to be the reference implementation - that is, whether we should
|
||||
/// invoke it like the reference kernel. This is mainly used with `run()` to
|
||||
/// differentiate which implementation that should be invoked.
|
||||
///
|
||||
/// - SIGNATURE is the operation signature.
|
||||
/// - Conv is a convolution instance created by the CK Builder API.
|
||||
/// - InDataType, WeiDataType, OutDataType are the types of the respective tensors.
|
||||
template <typename Conv,
|
||||
auto SIGNATURE,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
concept RefConvInstance = requires(Conv& conv,
|
||||
InDataType* input,
|
||||
WeiDataType* weight,
|
||||
OutDataType* output,
|
||||
ck::utils::conv::ConvParam param) {
|
||||
requires ValidConvSignature<SIGNATURE>;
|
||||
{ conv.Run(input, weight, output, param) };
|
||||
};
|
||||
|
||||
/// @brief Generic `run` implementation for forward/backwards reference kernels.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature of the operation to perform.
|
||||
///
|
||||
/// @return std::tuple<bool, float> - whether the problem is supported and
|
||||
/// kernel execution time (0.0f for reference).
|
||||
/// @see run()
|
||||
template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
[[nodiscard]] RunResult
|
||||
run(RefConvInstance<SIGNATURE, InDataType, WeiDataType, OutDataType> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
InDataType* input,
|
||||
WeiDataType* weight,
|
||||
OutDataType* output)
|
||||
{
|
||||
// We don't want to compute the output dims manually, just get
|
||||
// them via the existing infrastructure
|
||||
const auto param = args.to_ck_conv_param();
|
||||
|
||||
// TODO: The reference convolution is currently missing a few features.
|
||||
// Just throw for now, but regard these as TODO items that should be resolved
|
||||
// eventually.
|
||||
|
||||
if(!args.make_input_descriptor().is_packed())
|
||||
return RunResult::not_supported("TODO: Support non-packed input tensor in reference conv");
|
||||
|
||||
if(!args.make_weight_descriptor().is_packed())
|
||||
return RunResult::not_supported("TODO: Support non-packed weight tensor in reference conv");
|
||||
|
||||
if(!args.make_output_descriptor().is_packed())
|
||||
return RunResult::not_supported("TODO: Support non-packed output tensor in reference conv");
|
||||
|
||||
conv.Run(input, weight, output, param);
|
||||
return RunResult::from_runtime(0); // ref conv does not return a meaningful runtime.
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// @brief Concept for checking whether this is the reference convolution
|
||||
/// forward implementation.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept RefConvFwdInstance =
|
||||
detail::RefConvInstance<Conv, SIGNATURE, const void*, const void*, void*> &&
|
||||
ConvDirectionIsForward<SIGNATURE>;
|
||||
|
||||
/// @brief `run()` specialization for forward convolution and the reference
|
||||
/// forward implementation.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature of the operation to perform. Must be forwards.
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> &&
|
||||
// TODO: Maybe we can unify this implementation for bwd/weight too?
|
||||
// for now, just concern outselves with reference and see when the
|
||||
// rest of the bwd/weight plumbing is there.
|
||||
ConvDirectionIsForward<SIGNATURE>
|
||||
[[nodiscard]] RunResult run(RefConvFwdInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs)
|
||||
{
|
||||
return detail::run(conv, args, inputs.input, inputs.weight, outputs.output);
|
||||
}
|
||||
|
||||
/// @brief Concept for checking whether this is the reference convolution
|
||||
/// backward weight implementation.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept RefConvBwdWeightInstance =
|
||||
detail::RefConvInstance<Conv, SIGNATURE, const void*, void*, const void*> &&
|
||||
ConvDirectionIsBackwardWeight<SIGNATURE>;
|
||||
|
||||
/// @brief `run()` specialization for forward convolution and the reference
|
||||
/// backward weight implementation.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature of the operation to perform. Must be backwards weight.
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
[[nodiscard]] RunResult run(RefConvBwdWeightInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs)
|
||||
{
|
||||
return detail::run(conv, args, inputs.input, outputs.weight, inputs.output);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -1,88 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/conv_fwd.hpp"
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
/// This file contains the implementation details for invoking/testing
|
||||
/// grouped convolution operations using the reference implementation.
|
||||
/// The main item is the `run()` function, which is the primary way to
|
||||
/// invoke the reference execution mechanism.
|
||||
/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`,
|
||||
/// but its made specific to the reference implementation, which is
|
||||
/// invoked in a slightly different way.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
/// @brief Concept for checking whether this is the reference convolution
|
||||
/// implementation.
|
||||
///
|
||||
/// This concept is used to tell whether a convolution implementation is
|
||||
/// likely to be the reference implementation - that is, whether we should
|
||||
/// invoke it like the reference kernel. This is mainly used with `run()` to
|
||||
/// differentiate which implementation that should be invoked.
|
||||
///
|
||||
/// - SIGNATURE is the operation signature.
|
||||
/// - Conv is a convolution instance created by the CK Builder API.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept RefConvInstance = requires(Conv& conv,
|
||||
const void* input,
|
||||
const void* weight,
|
||||
void* output,
|
||||
ck::utils::conv::ConvParam param) {
|
||||
{ conv.Run(input, weight, output, param) };
|
||||
};
|
||||
|
||||
/// @brief `run()` specialization for forward convolution and the reference
|
||||
/// implementation.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
/// @throws std::runtime_error if the arguments weren't actually valid for the
|
||||
/// operation. This should be caught and reported by the testing framework.
|
||||
///
|
||||
/// @return std::tuple<bool, float> - whether the problem is supported and
|
||||
/// kernel execution time (0.0f for reference).
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> &&
|
||||
// TODO: Maybe we can unify this implementation for bwd/weight too?
|
||||
// for now, just concern outselves with reference and see when the
|
||||
// rest of the bwd/weight plumbing is there.
|
||||
ConvDirectionIsForward<SIGNATURE>
|
||||
std::tuple<bool, float> run(RefConvInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs)
|
||||
{
|
||||
// We don't want to compute the output dims manually, just get
|
||||
// them via the existing infrastructure
|
||||
const auto param = args.to_ck_conv_param();
|
||||
|
||||
// TODO: The reference convolution is currently missing a few features.
|
||||
// Just throw for now, but regard these as TODO items that should be resolved
|
||||
// eventually.
|
||||
|
||||
if(!args.make_input_descriptor().is_packed())
|
||||
{
|
||||
std::cout << "TODO: Support non-packed input tensor in reference conv" << std::endl;
|
||||
return std::make_tuple(false, 0.0f);
|
||||
}
|
||||
if(!args.make_weight_descriptor().is_packed())
|
||||
{
|
||||
std::cout << "TODO: Support non-packed weight tensor in reference conv" << std::endl;
|
||||
return std::make_tuple(false, 0.0f);
|
||||
}
|
||||
if(!args.make_output_descriptor().is_packed())
|
||||
{
|
||||
std::cout << "TODO: Support non-packed output tensor in reference conv" << std::endl;
|
||||
return std::make_tuple(false, 0.0f);
|
||||
}
|
||||
|
||||
conv.Run(inputs.input, inputs.weight, outputs.output, param);
|
||||
return std::make_tuple(true, 0.0f);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/testing/type_traits.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
|
||||
@@ -3,7 +3,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <concepts>
|
||||
#include <string_view>
|
||||
#include <string>
|
||||
#include <iosfwd>
|
||||
|
||||
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_buffer.hpp"
|
||||
@@ -288,6 +292,57 @@ ValidationReport validate(const Args<SIGNATURE>& args,
|
||||
Outputs<SIGNATURE> actual,
|
||||
Outputs<SIGNATURE> expected) = delete;
|
||||
|
||||
/// @brief This structure represents the result of a run operation.
|
||||
///
|
||||
/// The structure contains multiple fields with information about
|
||||
/// how the operation completed (or not). See those for more info.
|
||||
struct RunResult
|
||||
{
|
||||
/// If this value is not set to `std::nullopt`, there was a problem
|
||||
/// while running the algorithm. In this case, the outputs are not
|
||||
/// valid (though may be partially or completely overwritten), and
|
||||
/// the optional contains a short debug message that indicates the
|
||||
/// problem.
|
||||
std::optional<std::string> error = std::nullopt;
|
||||
|
||||
/// The runtime of the kernel in milliseconds, if measured. Whether the
|
||||
/// runtime is measured at all depends on the stream configuration
|
||||
/// passed to run(). 0 if not measured or if there was an error. This
|
||||
/// value is averaged over the total amount of runs actually done. Again,
|
||||
/// this is usually configured via the stream config.
|
||||
float runtime = 0.f;
|
||||
|
||||
/// @brief Utility function for constructing a RunResult from an unsupported operation.
|
||||
///
|
||||
/// @param msg A short debug message that will be included in the result.
|
||||
constexpr static RunResult not_supported(std::string_view msg)
|
||||
{
|
||||
return RunResult{.error = std::string(msg)};
|
||||
}
|
||||
|
||||
/// @brief Utility function for constructing a RunResult from an average runtime,
|
||||
/// indicating a successful operation.
|
||||
///
|
||||
/// @param runtime The runtime of the kernel in milliseconds.
|
||||
constexpr static RunResult from_runtime(const float runtime)
|
||||
{
|
||||
return RunResult{.runtime = runtime};
|
||||
}
|
||||
|
||||
/// @brief Returns whether this algorithm executed successfully.
|
||||
///
|
||||
/// In this case there should be no message in `error`.
|
||||
bool is_supported() const { return !this->error.has_value(); }
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const RunResult& result)
|
||||
{
|
||||
if(result.error.has_value())
|
||||
return os << "invalid run (" << result.error.value() << ")";
|
||||
else
|
||||
return os << "successful run (" << result.runtime << " ms)";
|
||||
}
|
||||
|
||||
/// @brief Invoke a device operation created by CK Builder.
|
||||
///
|
||||
/// This is the main function used to invoke a particular device operation
|
||||
@@ -318,13 +373,14 @@ ValidationReport validate(const Args<SIGNATURE>& args,
|
||||
/// @param outputs The output tensor data. The contents will be overwritten by
|
||||
/// this function.
|
||||
/// @param s_conf Stream config used to launch kernel.
|
||||
/// @return std::tuple<bool, float> - whether the problem is supported and
|
||||
/// kernel execution time (0.0f if s_conf time_kernel is false).
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @note This function is explicitly deleted to generate compile errors
|
||||
/// for missing implementations.
|
||||
///
|
||||
/// @see RunResult
|
||||
template <auto SIGNATURE, typename Operation, typename StreamConf>
|
||||
std::tuple<bool, float> run(Operation& operation,
|
||||
[[nodiscard]] RunResult run(Operation& operation,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs,
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
#include <string_view>
|
||||
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
|
||||
/// testing.hpp requires developers of a type of SIGNATURE to implement
|
||||
/// quite a lot of functionality for each SIGNATURE. For example, next
|
||||
/// to `Args`, `Inputs`, `Outputs`, `run`, they also have to define
|
||||
|
||||
@@ -51,6 +51,9 @@ struct ValidationReport
|
||||
/// The number of elements which were bitwise 0.
|
||||
uint64_t zero_elements;
|
||||
|
||||
// Max error.
|
||||
double max_error;
|
||||
|
||||
/// @brief Check whether both the output and reference tensor were both all zeros.
|
||||
///
|
||||
/// If both tensors are all zero, it indicates either an incorrect testing setup
|
||||
@@ -133,11 +136,12 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
// Initial pass: count errors
|
||||
|
||||
// Allocate and reset counter
|
||||
auto d_counters = alloc_buffer(sizeof(uint64_t) * 2);
|
||||
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2));
|
||||
auto d_counters = alloc_buffer(sizeof(uint64_t) * 3);
|
||||
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 3));
|
||||
|
||||
auto d_error_count = &reinterpret_cast<uint64_t*>(d_counters.get())[0];
|
||||
auto d_zero_count = &reinterpret_cast<uint64_t*>(d_counters.get())[1];
|
||||
auto d_max_error = &reinterpret_cast<double*>(d_counters.get())[2];
|
||||
|
||||
tensor_foreach(descriptor.get_lengths(), [=](auto index) {
|
||||
using CKType = typename factory::internal::DataTypeToCK<DT>::type;
|
||||
@@ -157,6 +161,7 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
const auto r = static_cast<double>(type_convert<float>(b));
|
||||
const auto err = std::abs(o - r);
|
||||
|
||||
atomicMax(d_max_error, err);
|
||||
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
|
||||
{
|
||||
// We expect the number of errors to be very low, so just use an atomic
|
||||
@@ -188,6 +193,8 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
|
||||
uint64_t zero_count = 0;
|
||||
check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
|
||||
double max_error = 0;
|
||||
check_hip(hipMemcpy(&max_error, d_max_error, sizeof(double), hipMemcpyDeviceToHost));
|
||||
|
||||
// TODO: Gather detailed coordinates.
|
||||
|
||||
@@ -196,6 +203,7 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
.wrong_elements = error_count,
|
||||
.total_elements = descriptor.get_element_size(),
|
||||
.zero_elements = zero_count,
|
||||
.max_error = max_error,
|
||||
});
|
||||
|
||||
return reports_.back().is_ok();
|
||||
|
||||
@@ -157,6 +157,7 @@ enum class PipelineVersion
|
||||
V3,
|
||||
V4,
|
||||
V5,
|
||||
V6,
|
||||
WEIGHT_ONLY
|
||||
};
|
||||
|
||||
@@ -328,6 +329,7 @@ inline std::string_view to_string(PipelineVersion ver)
|
||||
case V3: return "V3";
|
||||
case V4: return "V4";
|
||||
case V5: return "V5";
|
||||
case V6: return "V6";
|
||||
case WEIGHT_ONLY: return "WEIGHT_ONLY";
|
||||
default: return "Unknown";
|
||||
}
|
||||
|
||||
@@ -168,7 +168,7 @@ add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/ck/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
|
||||
conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp
|
||||
)
|
||||
)
|
||||
target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
|
||||
|
||||
set(BWD_WEIGHT_TESTS
|
||||
|
||||
@@ -1,23 +1,30 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/builder/testing/conv/bwd_weight.hpp"
|
||||
#include "ck_tile/builder/testing/conv/bwd_weight_ck.hpp"
|
||||
#include "ck_tile/builder/testing/conv/reference.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
using ck_tile::test::MatchesReference;
|
||||
using ck_tile::test::SuccessfulRun;
|
||||
|
||||
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::BF16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = NGCW}},
|
||||
.input = {.config = {.layout = GNWC}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NGKW}}};
|
||||
.output = {.config = {.layout = GNWK}}};
|
||||
|
||||
constexpr auto ALGORITHM =
|
||||
cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{}
|
||||
@@ -30,14 +37,58 @@ constexpr auto ALGORITHM =
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
|
||||
TEST(BwdWeight_1DBf16_CShuffle_V3, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3",
|
||||
expected_transfer_parameters,
|
||||
"Filter1x1Stride1Pad0",
|
||||
"NGCW,GKXC,NGKW",
|
||||
"GNWC,GKXC,GNWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"Intrawave",
|
||||
"v2"});
|
||||
}
|
||||
|
||||
TEST(BwdWeight_1DBf16_CShuffle_V3, Execution)
|
||||
{
|
||||
if(!ck_tile::get_device_name().starts_with("gfx9"))
|
||||
{
|
||||
// Note: XDL kernel
|
||||
GTEST_SKIP() << "unsupported architecture";
|
||||
}
|
||||
|
||||
ckt::Args<SIGNATURE> args = {
|
||||
.lengths =
|
||||
{
|
||||
.batch_size = 16,
|
||||
.groups = 1,
|
||||
.input_channels = 32,
|
||||
.output_channels = 48,
|
||||
.image = {.width = 64},
|
||||
.filter = {.width = 1},
|
||||
},
|
||||
.filter_strides = {.width = 1},
|
||||
.filter_dilation = {.width = 1},
|
||||
.input_left_pad = {.width = 0},
|
||||
.input_right_pad = {.width = 0},
|
||||
.a_elementwise_op = {},
|
||||
.b_elementwise_op = {},
|
||||
.cde_elementwise_op = {},
|
||||
};
|
||||
|
||||
auto inputs = ckt::alloc_inputs(args);
|
||||
auto outputs = ckt::alloc_outputs(args);
|
||||
auto reference = ckt::alloc_outputs(args);
|
||||
|
||||
ckt::init_inputs(args, inputs.get());
|
||||
|
||||
auto conv = Instance{};
|
||||
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
|
||||
|
||||
auto ref_conv = Reference{};
|
||||
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
|
||||
|
||||
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
|
||||
}
|
||||
|
||||
@@ -4,8 +4,9 @@
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd_ck.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd_reference.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd_ck.hpp"
|
||||
#include "ck_tile/builder/testing/conv/reference.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
@@ -14,6 +15,7 @@ namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
using ck_tile::test::MatchesReference;
|
||||
using ck_tile::test::SuccessfulRun;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
@@ -50,10 +52,11 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, Create)
|
||||
"MNKPadding"});
|
||||
}
|
||||
|
||||
TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd)
|
||||
TEST(Fwd2DFp16_CShufV3_GNHWC, Execution)
|
||||
{
|
||||
if(!ck_tile::get_device_name().starts_with("gfx9"))
|
||||
{
|
||||
// Note: XDL kernel
|
||||
GTEST_SKIP() << "unsupported architecture";
|
||||
}
|
||||
|
||||
@@ -91,10 +94,10 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd)
|
||||
ckt::init_inputs(args, inputs.get());
|
||||
|
||||
auto conv = Instance{};
|
||||
ckt::run(conv, args, inputs.get(), outputs.get());
|
||||
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
|
||||
|
||||
auto ref_conv = Reference{};
|
||||
ckt::run(ref_conv, args, inputs.get(), reference.get());
|
||||
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
|
||||
|
||||
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
|
||||
}
|
||||
|
||||
@@ -1,35 +1,47 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/builder/testing/conv/bwd_weight.hpp"
|
||||
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
|
||||
#include "ck_tile/builder/testing/conv/reference.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "utils/ckb_conv_tile_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
namespace {
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
using ck_tile::test::MatchesReference;
|
||||
using ck_tile::test::SuccessfulRun;
|
||||
|
||||
TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
constexpr auto SIGNATURE = cku::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = NHWGC}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = NHWGK}}};
|
||||
|
||||
constexpr auto ALGORITHM =
|
||||
cku::ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(ckb::TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(cku::TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(cku::TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(ckt::TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
|
||||
TEST(BwdWeight_2D_FP16_NHWGC, Create)
|
||||
{
|
||||
constexpr ConvSignature BwdWeightConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
|
||||
constexpr auto BwdWeightConvAlgorithm =
|
||||
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
using Builder = ConvBuilder<BwdWeightConvSignature, BwdWeightConvAlgorithm>;
|
||||
run_ck_tile_test<Builder>({
|
||||
cku::run_ck_tile_test<Builder>({
|
||||
"grouped_convolution_backward_weight",
|
||||
"fp16",
|
||||
"NHWGC_GKYXC_NHWGK",
|
||||
@@ -49,4 +61,38 @@ TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
TEST(BwdWeight_2D_FP16_NHWGC, Execution)
|
||||
{
|
||||
ckt::Args<SIGNATURE> args = {
|
||||
.lengths =
|
||||
{
|
||||
.batch_size = 2,
|
||||
.groups = 4,
|
||||
.input_channels = 32,
|
||||
.output_channels = 48,
|
||||
.image = {.width = 32, .height = 56},
|
||||
.filter = {.width = 3, .height = 3},
|
||||
},
|
||||
.filter_strides = {.width = 1, .height = 1},
|
||||
.filter_dilation = {.width = 1, .height = 1},
|
||||
.input_left_pad = {.width = 0, .height = 0},
|
||||
.input_right_pad = {.width = 0, .height = 0},
|
||||
.a_elementwise_op = {},
|
||||
.b_elementwise_op = {},
|
||||
.cde_elementwise_op = {},
|
||||
};
|
||||
|
||||
auto inputs = ckt::alloc_inputs(args);
|
||||
auto outputs = ckt::alloc_outputs(args);
|
||||
auto reference = ckt::alloc_outputs(args);
|
||||
|
||||
ckt::init_inputs(args, inputs.get());
|
||||
|
||||
auto conv = Instance{};
|
||||
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
|
||||
|
||||
auto ref_conv = Reference{};
|
||||
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
|
||||
|
||||
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
|
||||
}
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
#include "utils/ckb_conv_tile_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd_reference.hpp"
|
||||
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
|
||||
#include "ck_tile/builder/testing/conv/reference.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
@@ -13,6 +13,9 @@ namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
using ck_tile::test::MatchesReference;
|
||||
using ck_tile::test::SuccessfulRun;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::FORWARD,
|
||||
@@ -75,10 +78,10 @@ TEST(Fwd2DFp16_CShufV3_NHWGC, EndToEnd)
|
||||
ckt::init_inputs(args, inputs.get());
|
||||
|
||||
auto conv = Instance{};
|
||||
ckt::run(conv, args, inputs.get(), outputs.get());
|
||||
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
|
||||
|
||||
auto ref_conv = Reference{};
|
||||
ckt::run(ref_conv, args, inputs.get(), reference.get());
|
||||
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
|
||||
|
||||
EXPECT_THAT(outputs.get(), ck_tile::test::MatchesReference(args, reference.get()));
|
||||
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
|
||||
}
|
||||
|
||||
@@ -5,11 +5,14 @@
|
||||
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
|
||||
using ck_tile::test::HipError;
|
||||
using ck_tile::test::HipSuccess;
|
||||
using ck_tile::test::InstanceMatcher;
|
||||
using ck_tile::test::InstanceSet;
|
||||
using ck_tile::test::StringEqWithDiff;
|
||||
using ck_tile::test::SuccessfulRun;
|
||||
|
||||
TEST(InstanceSet, FromFactory)
|
||||
{
|
||||
@@ -107,3 +110,17 @@ TEST(HipStatusMatcher, Basic)
|
||||
EXPECT_THAT(hipSuccess, Not(HipError(hipErrorInvalidValue)));
|
||||
EXPECT_THAT(hipErrorOutOfMemory, Not(HipError(hipErrorInvalidValue)));
|
||||
}
|
||||
|
||||
TEST(RunResultMatcher, Basic)
|
||||
{
|
||||
EXPECT_THAT(ckt::RunResult::from_runtime(0), SuccessfulRun());
|
||||
EXPECT_THAT(ckt::RunResult::not_supported("test error"), Not(SuccessfulRun()));
|
||||
}
|
||||
|
||||
TEST(RunResultMatcher, ExplainMatchResult)
|
||||
{
|
||||
testing::StringMatchResultListener listener;
|
||||
EXPECT_TRUE(!ExplainMatchResult(
|
||||
SuccessfulRun(), ckt::RunResult::not_supported("test error"), &listener));
|
||||
EXPECT_THAT(listener.str(), StringEqWithDiff("run failed: test error"));
|
||||
}
|
||||
|
||||
@@ -339,4 +339,22 @@ void HipStatusMatcher::DescribeNegationTo(std::ostream* os) const
|
||||
return ::testing::MakeMatcher(new HipStatusMatcher(error));
|
||||
}
|
||||
|
||||
bool RunResultMatcher::MatchAndExplain(builder::test::RunResult actual,
|
||||
::testing::MatchResultListener* listener) const
|
||||
{
|
||||
if(actual.error.has_value() && listener)
|
||||
*listener << "run failed: " << actual.error.value();
|
||||
|
||||
return actual.is_supported();
|
||||
}
|
||||
|
||||
void RunResultMatcher::DescribeTo(std::ostream* os) const { *os << "successful run"; }
|
||||
|
||||
void RunResultMatcher::DescribeNegationTo(std::ostream* os) const { *os << "unsuccessful run"; }
|
||||
|
||||
::testing::Matcher<builder::test::RunResult> SuccessfulRun()
|
||||
{
|
||||
return ::testing::MakeMatcher(new RunResultMatcher());
|
||||
}
|
||||
|
||||
} // namespace ck_tile::test
|
||||
|
||||
@@ -161,6 +161,23 @@ struct HipStatusMatcher : public ::testing::MatcherInterface<hipError_t>
|
||||
/// @param error The error to expect.
|
||||
::testing::Matcher<hipError_t> HipError(hipError_t error);
|
||||
|
||||
/// @brief RunResult matcher
|
||||
///
|
||||
/// `ckt::run` returns a RunResult which indicates whether there was any
|
||||
/// problem while running the algorithm. This matcher is used to match those
|
||||
/// values.
|
||||
struct RunResultMatcher : public ::testing::MatcherInterface<builder::test::RunResult>
|
||||
{
|
||||
bool MatchAndExplain(builder::test::RunResult actual,
|
||||
::testing::MatchResultListener* listener) const override;
|
||||
void DescribeTo(std::ostream* os) const override;
|
||||
void DescribeNegationTo(std::ostream* os) const override;
|
||||
};
|
||||
|
||||
/// @brief Construct a Google Test matcher that checks that a ckt::run result
|
||||
/// was successful.
|
||||
::testing::Matcher<builder::test::RunResult> SuccessfulRun();
|
||||
|
||||
template <auto SIGNATURE>
|
||||
struct ReferenceOutputMatcher
|
||||
: public ::testing::MatcherInterface<builder::test::Outputs<SIGNATURE>>
|
||||
@@ -180,6 +197,21 @@ struct ReferenceOutputMatcher
|
||||
if(listener->IsInterested() && !errors.empty())
|
||||
{
|
||||
*listener << errors.size() << " tensors failed to validate";
|
||||
|
||||
for(const auto& e : errors)
|
||||
{
|
||||
*listener << "\n - " << e.tensor_name << ": ";
|
||||
|
||||
if(e.is_all_zero())
|
||||
*listener << "all elements in actual and expected tensors are zero";
|
||||
else
|
||||
{
|
||||
// Round to 2 digits
|
||||
const float percentage = e.wrong_elements * 10000 / e.total_elements / 100.f;
|
||||
*listener << e.wrong_elements << "/" << e.total_elements
|
||||
<< " incorrect elements (~" << percentage << "%)";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errors.empty();
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_foreach.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
@@ -296,5 +296,8 @@ TEST(MatchesReference, Incorrect)
|
||||
testing::StringMatchResultListener listener;
|
||||
EXPECT_TRUE(!ExplainMatchResult(MatchesReference(args, expected), actual, &listener));
|
||||
|
||||
EXPECT_THAT(listener.str(), StringEqWithDiff("1 tensors failed to validate"));
|
||||
EXPECT_THAT(listener.str(),
|
||||
StringEqWithDiff( //
|
||||
"1 tensors failed to validate\n"
|
||||
" - a: 625/625 incorrect elements (~100%)"));
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
|
||||
@@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
|
||||
@@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stri
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
|
||||
@@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
|
||||
@@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
|
||||
@@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stri
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "../../builder/test/utils/ckb_conv_tile_test_configs.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd.hpp"
|
||||
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
auto conv = Instance{};
|
||||
bool is_supported;
|
||||
float avg_time;
|
||||
std::tie(is_supported, avg_time) = ckt::run(conv, args, inputs, outputs, s_conf);
|
||||
return std::make_tuple(is_supported, avg_time, conv.GetInstanceString());
|
||||
auto conv = Instance{};
|
||||
ckt::RunResult result = ckt::run(conv, args, inputs, outputs, s_conf);
|
||||
return std::make_tuple(result.is_supported(), result.runtime, conv.GetInstanceString());
|
||||
|
||||
225
include/ck/BUILD_TIME_OPTIMIZATION.md
Normal file
225
include/ck/BUILD_TIME_OPTIMIZATION.md
Normal file
@@ -0,0 +1,225 @@
|
||||
# Build Time Optimization
|
||||
|
||||
Tracking issue: [#3575](https://github.com/ROCm/composable_kernel/issues/3575)
|
||||
|
||||
This document describes techniques for reducing C++ template instantiation overhead in the Composable Kernel codebase.
|
||||
|
||||
## Why Build Time Matters
|
||||
|
||||
Composable Kernel relies heavily on C++ template metaprogramming to achieve GPU kernels with no runtime abstraction penalty. However, deep template instantiation can significantly impact build times. A single translation unit may trigger hundreds of thousands of template instantiations, with each instantiation adding to compile time.
|
||||
|
||||
## Key Types
|
||||
|
||||
This codebase uses compile-time types to enable zero-overhead abstractions:
|
||||
|
||||
- `Number<N>` - compile-time integer, enables static dispatch and compile-time arithmetic
|
||||
- `Sequence<Is...>` - compile-time integer sequence, used for dimension ordering and index manipulation
|
||||
- `Tuple<Ts...>` - heterogeneous container holding different types, used for tensor descriptors and transforms
|
||||
|
||||
These types allow the compiler to fully unroll loops, eliminate branches, and inline all operations - producing GPU kernels with no runtime abstraction cost.
|
||||
|
||||
## Optimization Techniques
|
||||
|
||||
### 1. Replace Recursive Templates with Pack Expansion
|
||||
|
||||
Recursive template patterns create O(N) instantiation depth - the compiler must instantiate each level before proceeding to the next:
|
||||
|
||||
```
|
||||
sequence_gen_impl<5, F>
|
||||
→ sequence_gen_impl<4, F>
|
||||
→ sequence_gen_impl<3, F>
|
||||
→ ...
|
||||
```
|
||||
|
||||
Using `__make_integer_seq` (Clang/MSVC) combined with pack expansion reduces this to constant depth - the compiler generates the entire sequence in one step internally, without recursive template instantiation.
|
||||
|
||||
**Before** (O(N) recursive instantiation):
|
||||
|
||||
```cpp
|
||||
template <index_t N, typename F, index_t... Is>
|
||||
struct sequence_gen_impl
|
||||
{
|
||||
using type = typename sequence_gen_impl<N-1, F, F{}(Number<N-1>{}), Is...>::type;
|
||||
};
|
||||
|
||||
template <typename F, index_t... Is>
|
||||
struct sequence_gen_impl<0, F, Is...>
|
||||
{
|
||||
using type = Sequence<Is...>;
|
||||
};
|
||||
```
|
||||
|
||||
**After** (constant depth using compiler intrinsic + pack expansion):
|
||||
|
||||
```cpp
|
||||
namespace detail {
|
||||
|
||||
template <typename T, T... Is>
|
||||
struct sequence_gen_helper
|
||||
{
|
||||
// Apply functor F to all indices via pack expansion
|
||||
// F{}(Number<0>{}), F{}(Number<1>{}), ..., F{}(Number<N-1>{})
|
||||
template <typename F>
|
||||
using apply = Sequence<F{}(Number<Is>{})...>;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <index_t N, typename F>
|
||||
struct sequence_gen
|
||||
{
|
||||
// __make_integer_seq<sequence_gen_helper, index_t, N> produces
|
||||
// sequence_gen_helper<index_t, 0, 1, ..., N-1> with constant depth
|
||||
using type =
|
||||
typename __make_integer_seq<detail::sequence_gen_helper, index_t, N>::template apply<F>;
|
||||
};
|
||||
```
|
||||
|
||||
Note: This document assumes C++17 or later. While `std::make_integer_sequence` (introduced in C++14) is the standard library facility for generating integer sequences, it only produces `std::integer_sequence<T, ...>`. We use `__make_integer_seq` directly because it accepts any template as its first argument, enabling this pattern where the helper class receives the index pack directly.
|
||||
|
||||
### 2. Replace Lambdas with Named Functors
|
||||
|
||||
Each lambda expression creates a unique closure type, causing separate template instantiations at every call site. Named functors share a single type across all uses.
|
||||
|
||||
**Before** (lambda creates unique instantiations at each call site):
|
||||
|
||||
```cpp
|
||||
// The lambda inside transform_tensor_descriptor:
|
||||
generate_tuple([](auto i) { return Sequence<i>{}; }, Number<N>{});
|
||||
```
|
||||
|
||||
**After** (named functor shares instantiations):
|
||||
|
||||
```cpp
|
||||
// Define functor once
|
||||
struct generate_identity_sequence
|
||||
{
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto operator()(Number<I>) const
|
||||
{
|
||||
return Sequence<I>{};
|
||||
}
|
||||
};
|
||||
|
||||
// Use everywhere - shares instantiations
|
||||
generate_tuple(generate_identity_sequence{}, Number<N>{});
|
||||
```
|
||||
|
||||
This reduced `transform_tensor_descriptor` instantiations from 388 to 32 (92% reduction).
|
||||
|
||||
**Example: container_concat**
|
||||
|
||||
```cpp
|
||||
// Before: lambda creates unique type per call site
|
||||
// (unpack2 applies a functor to all elements from both tuples)
|
||||
template <typename... X, typename... Y>
|
||||
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
|
||||
{
|
||||
return unpack2([](auto&&... zs) { return make_tuple(forward<decltype(zs)>(zs)...); }, tx, ty);
|
||||
}
|
||||
|
||||
// After: named functor shares instantiations
|
||||
struct make_tuple_functor
|
||||
{
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr auto operator()(Ts&&... xs) const
|
||||
{
|
||||
return make_tuple(forward<Ts>(xs)...);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename... X, typename... Y>
|
||||
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
|
||||
{
|
||||
return unpack2(make_tuple_functor{}, tx, ty);
|
||||
}
|
||||
```
|
||||
|
||||
This reduced `container_concat` instantiations from 186 to 93 (50% reduction).
|
||||
|
||||
**Example: make_uniform_tuple**
|
||||
|
||||
For patterns that create tuples with repeated values:
|
||||
|
||||
```cpp
|
||||
// Before: unique lambda type at each call site
|
||||
generate_tuple([](auto) { return some_value; }, Number<N>{});
|
||||
|
||||
// After: dedicated helper function
|
||||
template <index_t N, typename T>
|
||||
__host__ __device__ constexpr auto make_uniform_tuple(T&& value)
|
||||
{
|
||||
return detail::make_uniform_tuple_impl(static_cast<T&&>(value), make_index_sequence<N>{});
|
||||
}
|
||||
|
||||
// Usage
|
||||
make_uniform_tuple<N>(some_value);
|
||||
```
|
||||
|
||||
### 3. Use Constexpr Loops Instead of Template Recursion
|
||||
|
||||
Template recursion creates N template instantiations for N iterations. A constexpr loop executes at compile time but only requires a single template instantiation. While both are O(N) in complexity, constexpr loops are significantly faster because they avoid the overhead of template instantiation.
|
||||
|
||||
**Before** (O(N) template instantiations):
|
||||
|
||||
```cpp
|
||||
// Simplified example - actual CK code used more complex recursive patterns
|
||||
template <index_t Target, typename Seq, index_t Pos, bool AtEnd>
|
||||
struct find_source_index_impl
|
||||
{
|
||||
static constexpr index_t value =
|
||||
(Seq::template At<Pos>() == Target) ? Pos : find_source_index_impl<Target, Seq, Pos+1, (Pos+1 == Seq::Size())>::value;
|
||||
};
|
||||
|
||||
template <index_t Target, typename Seq, index_t Pos>
|
||||
struct find_source_index_impl<Target, Seq, Pos, true>
|
||||
{
|
||||
static constexpr index_t value = -1; // not found
|
||||
};
|
||||
```
|
||||
|
||||
**After** (single instantiation with constexpr loop):
|
||||
|
||||
```cpp
|
||||
template <index_t Target, index_t... Is>
|
||||
__host__ __device__ constexpr index_t find_source_index(Sequence<Is...>)
|
||||
{
|
||||
// Simplified example - actual implementation handles empty sequences
|
||||
constexpr index_t values[] = {Is...};
|
||||
for(index_t i = 0; i < sizeof...(Is); ++i)
|
||||
if(values[i] == Target) return i;
|
||||
return -1; // not found
|
||||
}
|
||||
```
|
||||
|
||||
This reduced `sequence_map_inverse` instantiations from 45 to 10 (78% reduction) and wall-clock time by 95%.
|
||||
|
||||
### 4. Use Fold Expressions for Accumulation
|
||||
|
||||
Fold expressions (C++17) can replace recursive template patterns for accumulation operations.
|
||||
|
||||
**Before** (uses helper utilities that hide template recursion: `generate_tuple` recursively constructs a tuple of N elements, and `container_reduce` recursively reduces that tuple):
|
||||
|
||||
```cpp
|
||||
const auto element_space_size = container_reduce(
|
||||
generate_tuple([&](auto i) {
|
||||
return (lengths[i] - Number<1>{}) * strides[i];
|
||||
}, Number<N>{}),
|
||||
math::plus{}, Number<1>{});
|
||||
```
|
||||
|
||||
**After** (single fold expression):
|
||||
|
||||
```cpp
|
||||
template <typename... Lengths, typename... Strides, index_t... Is>
|
||||
__host__ __device__ constexpr auto compute_element_space_size(
|
||||
const Tuple<Lengths...>& lengths,
|
||||
const Tuple<Strides...>& strides,
|
||||
Sequence<Is...>)
|
||||
{
|
||||
return (LongNumber<1>{} + ... +
|
||||
((lengths[Number<Is>{}] - Number<1>{}) * strides[Number<Is>{}]));
|
||||
}
|
||||
```
|
||||
|
||||
This reduced `calculate_element_space_size` instantiations from 24 to 10 (58% reduction) and wall-clock time by 73%.
|
||||
@@ -55,9 +55,6 @@
|
||||
#ifndef CK_ENABLE_FP32
|
||||
#define CK_ENABLE_FP32 "ON"
|
||||
#endif
|
||||
#ifndef CK_ENABLE_TF32
|
||||
#define CK_ENABLE_TF32 "ON"
|
||||
#endif
|
||||
#ifndef CK_ENABLE_FP64
|
||||
#define CK_ENABLE_FP64 "ON"
|
||||
#endif
|
||||
@@ -88,10 +85,6 @@
|
||||
#cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@
|
||||
#endif
|
||||
|
||||
#ifndef CK_ENABLE_TF32
|
||||
#cmakedefine CK_ENABLE_TF32 @CK_ENABLE_TF32@
|
||||
#endif
|
||||
|
||||
#ifndef CK_ENABLE_FP64
|
||||
#cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@
|
||||
#endif
|
||||
|
||||
@@ -160,6 +160,7 @@ struct ThreadGroupTransferGlobal
|
||||
// check if src element is valid
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
|
||||
oob_thread_scratch_.template SetAsType<bool>(vgpr_data_idx_seq, is_src_valid);
|
||||
|
||||
// Vector length of elementwise operation
|
||||
constexpr auto get_elem_op_vec_len = []() {
|
||||
@@ -195,14 +196,12 @@ struct ThreadGroupTransferGlobal
|
||||
using dst_vector_type = vector_type_maker_t<DstData, VectorSize>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
|
||||
using vector_t = typename vector_type_maker<DstData, VectorSize>::type::type;
|
||||
|
||||
dst_vector_type op_r_v;
|
||||
|
||||
// Load data from memory in src_vector first
|
||||
src_vector_container src_vector =
|
||||
src_vector_container{grid_buf.template Get<src_vector_container_t, DoTranspose>(
|
||||
src_coord_.GetOffset(), true)};
|
||||
auto index = is_src_valid || !DoTranspose ? src_coord_.GetOffset() : 0;
|
||||
src_vector_container src_vector = src_vector_container{
|
||||
grid_buf.template Get<src_vector_container_t, DoTranspose>(index, true)};
|
||||
|
||||
// apply the src elementwise op and convert to DstData under the hood if needed
|
||||
static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) {
|
||||
@@ -213,9 +212,8 @@ struct ThreadGroupTransferGlobal
|
||||
// store result in dvgpr_ (static array holding loaded data).
|
||||
// At this point data is already converted to DstData type and
|
||||
// the elementwise operation has been applied
|
||||
dvgpr_.template SetAsType<dst_vector_t>(
|
||||
vgpr_data_idx_seq,
|
||||
is_src_valid ? op_r_v.template AsType<dst_vector_t>()[I0] : vector_t(0));
|
||||
src_dvgpr_.template SetAsType<dst_vector_t>(vgpr_data_idx_seq,
|
||||
op_r_v.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
// For each dimension move fwd, bwd or don't move
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
@@ -248,6 +246,39 @@ struct ThreadGroupTransferGlobal
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
constexpr auto ordered_fwd_step = StepsPerIteration{};
|
||||
|
||||
// OOB check
|
||||
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
|
||||
// calculate src data index and make sequence
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}(
|
||||
[&](auto i) { ordered_idx(i) = ordered_src_access_idx[i]; });
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order);
|
||||
}();
|
||||
|
||||
// make sequence to access vgpr data. Add zero as last element of src_data_idx_seq
|
||||
constexpr auto vgpr_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) {
|
||||
if constexpr(i.value < src_data_idx.Size())
|
||||
{
|
||||
return Number<src_data_idx[i]>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Number<0>{};
|
||||
}
|
||||
},
|
||||
Number<src_data_idx.Size() + 1>{});
|
||||
|
||||
auto op_r = src_dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq);
|
||||
const bool is_src_valid =
|
||||
oob_thread_scratch_.template GetAsType<bool>(vgpr_data_idx_seq);
|
||||
auto op_r_v = is_src_valid ? op_r : dst_vector_t(0);
|
||||
dst_dvgpr_.template SetAsType<dst_vector_t>(vgpr_data_idx_seq, op_r_v);
|
||||
});
|
||||
|
||||
// make forward steps
|
||||
// forward step for each iteration just add 1
|
||||
const auto dst_forward_steps = generate_tuple(
|
||||
@@ -352,7 +383,7 @@ struct ThreadGroupTransferGlobal
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
true,
|
||||
dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq));
|
||||
dst_dvgpr_.template GetAsType<dst_vector_t>(vgpr_data_idx_seq));
|
||||
|
||||
// For each dimension move fwd, bwd or don't move
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
@@ -389,6 +420,14 @@ struct ThreadGroupTransferGlobal
|
||||
return make_naive_tensor_descriptor_packed(access_lengths_as_tuple);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
|
||||
{
|
||||
constexpr auto access_lengths_as_tuple =
|
||||
container_push_back(sequence_to_tuple_of_number(NumberOfIterations{}), Number<1>{});
|
||||
|
||||
return make_naive_tensor_descriptor_packed(access_lengths_as_tuple);
|
||||
}
|
||||
|
||||
static constexpr auto thread_data_scratch_desc_ = decltype(GetThreadScratchDataDescriptor()){};
|
||||
using ThreadScratchData = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
@@ -396,7 +435,17 @@ struct ThreadGroupTransferGlobal
|
||||
decltype(thread_data_scratch_desc_),
|
||||
true>;
|
||||
|
||||
ThreadScratchData dvgpr_;
|
||||
static constexpr auto src_oob_thread_scratch_desc_ =
|
||||
decltype(GetSrcThreadScratchDescriptor()){};
|
||||
using OOBThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
bool,
|
||||
1,
|
||||
decltype(src_oob_thread_scratch_desc_),
|
||||
true>;
|
||||
|
||||
ThreadScratchData src_dvgpr_;
|
||||
ThreadScratchData dst_dvgpr_;
|
||||
OOBThreadScratch oob_thread_scratch_;
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
const ElementwiseOperation element_op_;
|
||||
|
||||
@@ -11,8 +11,6 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
|
||||
@@ -11,8 +11,6 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
|
||||
@@ -833,6 +833,26 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector access
|
||||
static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
|
||||
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
|
||||
|
||||
@@ -3,77 +3,21 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename DeviceOp, typename GridwiseOp, bool HasMainKBlockLoop, TailNumber TailNum>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
|
||||
const long_index_t b0_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
|
||||
const long_index_t b1_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
|
||||
const long_index_t c_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
|
||||
|
||||
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
|
||||
arg.p_a_grid + a_batch_offset,
|
||||
arg.p_b0_grid + b0_batch_offset,
|
||||
Tuple<>{}, // p_d0s_grid
|
||||
arg.p_b1_grid + b1_batch_offset,
|
||||
Tuple<>{}, // p_d1s_grid
|
||||
arg.p_c_grid + c_batch_offset,
|
||||
p_shared,
|
||||
arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
Tuple<>{}, // D0sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
arg.b1_grid_desc,
|
||||
Tuple<>{}, // D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op,
|
||||
arg.b0_element_op,
|
||||
arg.acc_element_op,
|
||||
arg.b1_element_op,
|
||||
arg.c_element_op,
|
||||
arg.block_2_ctile_map);
|
||||
#else
|
||||
ignore = arg;
|
||||
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
}
|
||||
|
||||
// Computes C = A * B0 * B1
|
||||
// MN = MK * KL * LN
|
||||
// ^^^^^^ (Acc0)
|
||||
@@ -157,88 +101,47 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
// to LPerWmma (A.k.a Gemm0 NPerWmma).
|
||||
static constexpr index_t NPerWmma = LPerWmma;
|
||||
|
||||
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
|
||||
// Transform operator or just not use one at all.
|
||||
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
|
||||
Sequence<1, 1, 1, 1, 1>,
|
||||
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
|
||||
GemmSpec,
|
||||
TensorSpecialization::Default, // ASpec
|
||||
TensorSpecialization::Default, // B0Spec
|
||||
TensorSpecialization::Default, // B1Spec
|
||||
TensorSpecialization::Default>; // CSpec
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
|
||||
const std::array<index_t, 3>& a_g_m_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeAGridDescriptor_AK0_M_AK1(
|
||||
Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
|
||||
Number<AK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
|
||||
const std::array<index_t, 3>& b0_g_l_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
|
||||
Number<BK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
|
||||
const std::array<index_t, 3>& b1_g_n_l_strides_vec)
|
||||
{
|
||||
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
|
||||
Number<L1>{});
|
||||
}
|
||||
|
||||
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
|
||||
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
|
||||
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
|
||||
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB0,
|
||||
index_t BatchStrideB1,
|
||||
index_t BatchStrideC)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB0_(BatchStrideB0),
|
||||
BatchStrideB1_(BatchStrideB1),
|
||||
BatchStrideC_(BatchStrideC)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB0_;
|
||||
index_t BatchStrideB1_;
|
||||
index_t BatchStrideC_;
|
||||
};
|
||||
using DeviceGemmGemmCommonBase =
|
||||
DeviceGemmGemm_Wmma_CShuffleV3_Common<DeviceOp,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
B0layout,
|
||||
Tuple<>, // D0sLayout
|
||||
B1Layout,
|
||||
Tuple<>, // D1sLayout
|
||||
CLayout,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
Tuple<>, // D0sDataType
|
||||
Tuple<>, // D1sDataType
|
||||
AElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
AK1,
|
||||
BK1,
|
||||
L1,
|
||||
MPerWmma,
|
||||
LPerWmma,
|
||||
BlkGemmPipelineVer,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
B0BlockTransferSrcVectorDim,
|
||||
B0BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
false>; // IsMultiD
|
||||
|
||||
// GridwiseOp
|
||||
using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3<
|
||||
@@ -260,12 +163,12 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
CElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
// InMemory Data Descriptor
|
||||
AGridDesc,
|
||||
B0GridDesc,
|
||||
typename DeviceGemmGemmCommonBase::AGridDesc,
|
||||
typename DeviceGemmGemmCommonBase::B0GridDesc,
|
||||
Tuple<>, // Ds0GridDesc
|
||||
B1GridDesc,
|
||||
typename DeviceGemmGemmCommonBase::B1GridDesc,
|
||||
Tuple<>, // Ds1GridDesc
|
||||
CGridDesc_M_N,
|
||||
typename DeviceGemmGemmCommonBase::CGridDesc_M_N,
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
@@ -312,339 +215,67 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
Transform::matrix_padder.PadN,
|
||||
DeviceGemmGemmCommonBase::GridDescriptorCreator::Transform::matrix_padder.PadN,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
|
||||
struct RawArg : public BaseArgument
|
||||
using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg<
|
||||
DeviceOp,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
B0layout,
|
||||
Tuple<>, // D0sLayout
|
||||
B1Layout,
|
||||
Tuple<>, // D1sLayout
|
||||
CLayout,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
Tuple<>, // D0sDataType,
|
||||
Tuple<>, // D1sDataType,
|
||||
AElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
AK1,
|
||||
BK1,
|
||||
L1,
|
||||
MPerWmma,
|
||||
LPerWmma,
|
||||
BlkGemmPipelineVer,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
B0BlockTransferSrcVectorDim,
|
||||
B0BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
false>; // IsMultiD
|
||||
// Invoker
|
||||
using Invoker = typename DeviceGemmGemmCommon::Invoker;
|
||||
|
||||
// Argument
|
||||
using Argument = typename DeviceGemmGemmCommon::Argument;
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
using arr3 = std::array<ck::index_t, 3>;
|
||||
|
||||
RawArg(const ADataType* p_a_grid_,
|
||||
const B0DataType* p_b0_grid_,
|
||||
const B1DataType* p_b1_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t O_,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB0,
|
||||
index_t StrideB1,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB0,
|
||||
index_t BatchStrideB1,
|
||||
index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op_,
|
||||
B0ElementwiseOperation b0_element_op_,
|
||||
AccElementwiseOperation acc_element_op_,
|
||||
B1ElementwiseOperation b1_element_op_,
|
||||
CElementwiseOperation c_element_op_)
|
||||
: p_a_grid{p_a_grid_},
|
||||
p_b0_grid{p_b0_grid_},
|
||||
p_b1_grid{p_b1_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
O{O_},
|
||||
batch_count{Batch},
|
||||
a_element_op{a_element_op_},
|
||||
b0_element_op{b0_element_op_},
|
||||
acc_element_op{acc_element_op_},
|
||||
b1_element_op{b1_element_op_},
|
||||
c_element_op{c_element_op_},
|
||||
compute_base_ptr_of_batch{BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC}
|
||||
{
|
||||
|
||||
a_g_m_k_lengths = arr3{batch_count, M, K};
|
||||
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
|
||||
|
||||
b0_g_n_k_lengths = arr3{batch_count, N, K};
|
||||
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
|
||||
|
||||
b1_g_o_n_lengths = arr3{batch_count, O, N};
|
||||
b1_g_o_n_strides =
|
||||
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
|
||||
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
|
||||
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
|
||||
|
||||
c_g_m_o_lengths = arr3{batch_count, M, O};
|
||||
c_g_m_o_strides = arr3{BatchStrideC, StrideC, 1}; // C layout [batch_count, M, O]
|
||||
|
||||
a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides);
|
||||
b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides);
|
||||
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
|
||||
c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides);
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
|
||||
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2ETileMap(c_grid_desc_m_n, 1, 1);
|
||||
}
|
||||
// Pointers
|
||||
const ADataType* p_a_grid;
|
||||
const B0DataType* p_b0_grid;
|
||||
const B1DataType* p_b1_grid;
|
||||
CDataType* p_c_grid;
|
||||
|
||||
// Raw Problem Size
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t O;
|
||||
index_t batch_count;
|
||||
|
||||
arr3 a_g_m_k_lengths;
|
||||
arr3 a_g_m_k_strides;
|
||||
arr3 b0_g_n_k_lengths;
|
||||
arr3 b0_g_n_k_strides;
|
||||
arr3 b1_g_o_n_lengths;
|
||||
arr3 b1_g_o_n_strides;
|
||||
arr3 c_g_m_o_lengths;
|
||||
arr3 c_g_m_o_strides;
|
||||
|
||||
AElementwiseOperation a_element_op;
|
||||
B0ElementwiseOperation b0_element_op;
|
||||
AccElementwiseOperation acc_element_op;
|
||||
B1ElementwiseOperation b1_element_op;
|
||||
CElementwiseOperation c_element_op;
|
||||
|
||||
// Grid descriptors and other mem calculators
|
||||
AGridDesc a_grid_desc;
|
||||
B0GridDesc b0_grid_desc;
|
||||
B1GridDesc b1_grid_desc;
|
||||
CGridDesc_M_N c_grid_desc_m_n;
|
||||
typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
typename GridwiseOp::DefaultBlock2ETileMap block_2_ctile_map;
|
||||
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
|
||||
{
|
||||
// Print lambda with env check and printf() style formmating.
|
||||
const char* curFunc = __func__;
|
||||
auto print = [&curFunc](const char* format, ...) -> void {
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
#endif
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
std::vfprintf(stdout, format, args);
|
||||
va_end(args);
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
|
||||
}
|
||||
};
|
||||
|
||||
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
print("DeviceOp: Arch err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
|
||||
std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
|
||||
std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
print("DeviceOp: gfx 11 does not support fp8\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
print("DeviceOp: Acc0 Type err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: A layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B layout must be Column\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
|
||||
is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B1 layout must be Column or Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<CLayout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: C layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Other padding modes have not been tested and do not get checked individually.
|
||||
if constexpr(GemmSpec != GemmSpecialization::Default &&
|
||||
GemmSpec != GemmSpecialization::MNKOPadding)
|
||||
{
|
||||
print("Padding mode must be default or MNKO\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Per wmma dimensions not equal to 16 are very untested.
|
||||
if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
|
||||
{
|
||||
print("M, L, N per Wmma must be 16\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
Tuple<>{},
|
||||
arg.b1_grid_desc,
|
||||
Tuple<>{},
|
||||
arg.c_grid_desc_m_n,
|
||||
arg.block_2_ctile_map))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
|
||||
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
|
||||
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
|
||||
const auto c_extent_lowest = arg.O;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
print("DeviceOp: Data Transfer Vector scalar err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check vector load/store requirement
|
||||
const auto a_stride_lowest =
|
||||
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
|
||||
const auto b0_stride_lowest =
|
||||
B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
|
||||
const auto b1_stride_lowest =
|
||||
B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
|
||||
const auto c_stride_lowest = arg.c_g_m_o_strides[2];
|
||||
|
||||
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
|
||||
c_stride_lowest == 1))
|
||||
{
|
||||
print("DeviceOp: Data Vectorize transfer err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return DeviceGemmGemmCommon::IsSupportedArgument(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
|
||||
return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::RawArg;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
|
||||
|
||||
const index_t grid_size = arg.batch_count * M0 * N0;
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
|
||||
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
|
||||
constexpr TailNumber tn = tail_number;
|
||||
|
||||
const auto kernel =
|
||||
kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp, GridwiseOp, has_loop, tn>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
|
||||
};
|
||||
|
||||
bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
|
||||
TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K);
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid pipeline version!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b0,
|
||||
@@ -669,28 +300,39 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<RawArg>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const B0DataType*>(p_b0),
|
||||
static_cast<const B1DataType*>(p_b1),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
Batch,
|
||||
StrideA,
|
||||
StrideB0,
|
||||
StrideB1,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideB1,
|
||||
BatchStrideC,
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
|
||||
std::array<const void*, DeviceGemmGemmCommonBase::NumD0Tensor> p_d0_grid{};
|
||||
std::array<const void*, DeviceGemmGemmCommonBase::NumD1Tensor> p_d1_grid{};
|
||||
std::array<index_t, DeviceGemmGemmCommonBase::NumD0Tensor> StrideD0s{}, BatchStrideD0s{};
|
||||
std::array<index_t, DeviceGemmGemmCommonBase::NumD1Tensor> StrideD1s, BatchStrideD1s{};
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const B0DataType*>(p_b0),
|
||||
p_d0_grid,
|
||||
static_cast<const B1DataType*>(p_b1),
|
||||
p_d1_grid,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
Batch,
|
||||
StrideA,
|
||||
StrideB0,
|
||||
StrideD0s,
|
||||
StrideB1,
|
||||
StrideD1s,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0s,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1s,
|
||||
BatchStrideC,
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
@@ -0,0 +1,902 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
#include <cstdarg>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
#include "ck/utility/integral_constant.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename DeviceOp,
|
||||
typename GridwiseOp,
|
||||
bool HasMainKBlockLoop,
|
||||
TailNumber TailNum,
|
||||
bool IsMultiD>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::Argument arg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
|
||||
const long_index_t b0_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
|
||||
const long_index_t b1_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
|
||||
const long_index_t c_e1_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCE1BasePtr(g_idx)));
|
||||
|
||||
auto [p_d0s_grid, p_d1s_grid] = [&]() {
|
||||
if constexpr(IsMultiD)
|
||||
{
|
||||
auto create_grid = [](auto NumTensor, auto func, auto& arg_grid, auto&& grid_pointer) {
|
||||
static_for<0, decltype(NumTensor)::value, 1>{}([&](auto In) {
|
||||
const long_index_t batch_offset = __builtin_amdgcn_readfirstlane(func(In));
|
||||
grid_pointer(In) = arg_grid(In) + batch_offset;
|
||||
});
|
||||
return std::move(grid_pointer);
|
||||
};
|
||||
auto get_d0_base_ptr = [&arg, &g_idx](auto d_idx) {
|
||||
return arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, d_idx);
|
||||
};
|
||||
auto get_d1_base_ptr = [&arg, &g_idx](auto d_idx) {
|
||||
return arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, d_idx);
|
||||
};
|
||||
auto d0s_grid = create_grid(ck::integral_constant<ck::index_t, DeviceOp::NumD0Tensor>{},
|
||||
get_d0_base_ptr,
|
||||
arg.p_d0s_grid,
|
||||
GridwiseOp::MakeD0sGridPointer());
|
||||
auto d1s_grid = create_grid(ck::integral_constant<ck::index_t, DeviceOp::NumD1Tensor>{},
|
||||
get_d1_base_ptr,
|
||||
arg.p_d1s_grid,
|
||||
GridwiseOp::MakeD1sGridPointer());
|
||||
return std::make_pair(d0s_grid, d1s_grid);
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::make_pair(Tuple<>{}, Tuple<>{});
|
||||
}
|
||||
}();
|
||||
|
||||
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
|
||||
arg.p_a_grid + a_batch_offset,
|
||||
arg.p_b0_grid + b0_batch_offset,
|
||||
p_d0s_grid,
|
||||
arg.p_b1_grid + b1_batch_offset,
|
||||
p_d1s_grid,
|
||||
arg.p_c_e1_grid + c_e1_batch_offset,
|
||||
p_shared,
|
||||
arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
arg.d0s_grid_desc,
|
||||
arg.b1_grid_desc,
|
||||
arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.c_e1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op,
|
||||
arg.b0_element_op,
|
||||
arg.acc_element_op,
|
||||
arg.b1_element_op,
|
||||
arg.cde1_element_op,
|
||||
arg.block_2_etile_map);
|
||||
#else
|
||||
ignore = arg;
|
||||
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
}
|
||||
|
||||
template <typename DeviceOp,
|
||||
GemmSpecialization GemmSpec,
|
||||
typename ALayout,
|
||||
typename B0layout,
|
||||
typename D0sLayout,
|
||||
typename B1Layout,
|
||||
typename D1sLayout,
|
||||
typename CE1Layout,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t LPerBlock, // Gemm0NPerBlock
|
||||
ck::index_t KPerBlock, // Gemm0KPerBlock
|
||||
ck::index_t NPerBlock, // Gemm1NPerBlock
|
||||
typename ADataType,
|
||||
typename B0DataType,
|
||||
typename B1DataType,
|
||||
typename AccDataType,
|
||||
typename CE1DataType,
|
||||
typename D0sDataType,
|
||||
typename D1sDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CDE1ElementwiseOperation,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t L1, // B1K1
|
||||
ck::index_t MPerWmma, // Gemm0/1 MPerWmma
|
||||
ck::index_t LPerWmma, // Gemm0/1 NPerWmma
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t B0BlockTransferSrcVectorDim,
|
||||
ck::index_t B0BlockTransferSrcScalarPerVector,
|
||||
ck::index_t B1BlockTransferSrcVectorDim,
|
||||
ck::index_t B1BlockTransferSrcScalarPerVector,
|
||||
ck::index_t CDE0BlockTransferSrcScalarPerVector,
|
||||
ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
bool IsMultiD = false>
|
||||
struct DeviceGemmGemm_Wmma_CShuffleV3_Common
|
||||
{
|
||||
static constexpr ck::index_t NumD0Tensor = []() {
|
||||
if constexpr(IsMultiD)
|
||||
{
|
||||
return DeviceOp::NumD0Tensor;
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
static constexpr ck::index_t NumD1Tensor = []() {
|
||||
if constexpr(IsMultiD)
|
||||
{
|
||||
return DeviceOp::NumD1Tensor;
|
||||
}
|
||||
return 0;
|
||||
}();
|
||||
|
||||
struct GridDescriptorCreator
|
||||
{
|
||||
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
|
||||
// Transform operator or just not use one at all.
|
||||
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
|
||||
Sequence<1, 1, 1, 1, 1>,
|
||||
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
|
||||
GemmSpec,
|
||||
TensorSpecialization::Default, // ASpec
|
||||
TensorSpecialization::Default, // B0Spec
|
||||
TensorSpecialization::Default, // B1Spec
|
||||
TensorSpecialization::Default>; // CSpec
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
|
||||
const std::array<index_t, 3>& a_g_m_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeAGridDescriptor_AK0_M_AK1(
|
||||
Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
|
||||
Number<AK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
|
||||
const std::array<index_t, 3>& b0_g_l_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
|
||||
Number<BK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
|
||||
const std::array<index_t, 3>& b1_g_n_l_strides_vec)
|
||||
{
|
||||
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
|
||||
Number<L1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeD0GridDescriptor(const std::array<index_t, 3>& d0_g_m_n_lengths_vec,
|
||||
const std::array<index_t, 3>& d0_g_m_n_strides_vec)
|
||||
{
|
||||
return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec);
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeD0sGridDescriptor(
|
||||
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_lengths_vec,
|
||||
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_strides_vec)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]);
|
||||
},
|
||||
Number<NumD0Tensor>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeD1sGridDescriptor(
|
||||
const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_lengths_vec,
|
||||
const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_strides_vec)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]);
|
||||
},
|
||||
Number<NumD1Tensor>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeE1GridDescriptor(const std::array<index_t, 3>& e1_g_m_n_lengths_vec,
|
||||
const std::array<index_t, 3>& e1_g_m_n_strides_vec)
|
||||
{
|
||||
return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec);
|
||||
}
|
||||
};
|
||||
|
||||
using AGridDesc = decltype(GridDescriptorCreator::MakeAGridDescriptor({}, {}));
|
||||
using B0GridDesc = decltype(GridDescriptorCreator::MakeB0GridDescriptor({}, {}));
|
||||
using D0sGridDesc =
|
||||
remove_cvref_t<decltype(GridDescriptorCreator::MakeD0sGridDescriptor({}, {}))>;
|
||||
using B1GridDesc = decltype(GridDescriptorCreator::MakeB1GridDescriptor({}, {}));
|
||||
using D1sGridDesc =
|
||||
remove_cvref_t<decltype(GridDescriptorCreator::MakeD1sGridDescriptor({}, {}))>;
|
||||
using E1GridDesc = decltype(GridDescriptorCreator::MakeE1GridDescriptor({}, {}));
|
||||
using CGridDesc_M_N =
|
||||
decltype(GridDescriptorCreator::Transform::MakeCGridDescriptor_M_N({}, {}));
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB0,
|
||||
index_t BatchStrideB1,
|
||||
index_t BatchStrideC)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB0_(BatchStrideB0),
|
||||
BatchStrideB1_(BatchStrideB1),
|
||||
BatchStrideC_E1_(BatchStrideC)
|
||||
{
|
||||
}
|
||||
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
|
||||
index_t BatchStrideB0,
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1)
|
||||
: BatchStrideA_(BatchStrideA0),
|
||||
BatchStrideB0_(BatchStrideB0),
|
||||
BatchStrideD0s_(BatchStrideD0s),
|
||||
BatchStrideB1_(BatchStrideB1),
|
||||
BatchStrideD1s_(BatchStrideD1s),
|
||||
BatchStrideC_E1_(BatchStrideE1)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCE1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_E1_);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
|
||||
Number<I> d0_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d0_idx]);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx,
|
||||
Number<I> d1_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB0_;
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s_;
|
||||
index_t BatchStrideB1_;
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s_;
|
||||
index_t BatchStrideC_E1_;
|
||||
};
|
||||
};
|
||||
|
||||
template <typename DeviceOp,
|
||||
GemmSpecialization GemmSpec,
|
||||
typename ALayout,
|
||||
typename B0layout,
|
||||
typename D0sLayout,
|
||||
typename B1Layout,
|
||||
typename D1sLayout,
|
||||
typename CE1Layout,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t LPerBlock, // Gemm0NPerBlock
|
||||
ck::index_t KPerBlock, // Gemm0KPerBlock
|
||||
ck::index_t NPerBlock, // Gemm1NPerBlock
|
||||
typename ADataType,
|
||||
typename B0DataType,
|
||||
typename B1DataType,
|
||||
typename AccDataType,
|
||||
typename CE1DataType,
|
||||
typename D0sDataType,
|
||||
typename D1sDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CDE1ElementwiseOperation,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t L1, // B1K1
|
||||
ck::index_t MPerWmma, // Gemm0/1 MPerWmma
|
||||
ck::index_t LPerWmma, // Gemm0/1 NPerWmma
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t B0BlockTransferSrcVectorDim,
|
||||
ck::index_t B0BlockTransferSrcScalarPerVector,
|
||||
ck::index_t B1BlockTransferSrcVectorDim,
|
||||
ck::index_t B1BlockTransferSrcScalarPerVector,
|
||||
ck::index_t CDE0BlockTransferSrcScalarPerVector,
|
||||
ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
bool IsMultiD = false>
|
||||
struct DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg
|
||||
{
|
||||
using GridwiseGemm = typename DeviceOp::GridwiseOp;
|
||||
using Common =
|
||||
DeviceGemmGemm_Wmma_CShuffleV3_Common<DeviceOp,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
B0layout,
|
||||
D0sLayout,
|
||||
B1Layout,
|
||||
D1sLayout,
|
||||
CE1Layout,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
AccDataType,
|
||||
CE1DataType,
|
||||
D0sDataType,
|
||||
D1sDataType,
|
||||
AElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CDE1ElementwiseOperation,
|
||||
AK1,
|
||||
BK1,
|
||||
L1,
|
||||
MPerWmma,
|
||||
LPerWmma,
|
||||
BlkGemmPipelineVer,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
B0BlockTransferSrcVectorDim,
|
||||
B0BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
CDE0BlockTransferSrcScalarPerVector,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
IsMultiD>;
|
||||
|
||||
static constexpr auto NumD0Tensor = Common::NumD0Tensor;
|
||||
static constexpr auto NumD1Tensor = Common::NumD1Tensor;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
using arr3 = std::array<ck::index_t, 3>;
|
||||
|
||||
Argument(const ADataType* p_a_grid_,
|
||||
const B0DataType* p_b0_grid_,
|
||||
std::array<const void*, NumD0Tensor> p_d0s_grid_,
|
||||
const B1DataType* p_b1_grid_,
|
||||
std::array<const void*, NumD1Tensor> p_d1s_grid_,
|
||||
CE1DataType* p_e1_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t O_,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB0,
|
||||
std::array<index_t, NumD0Tensor> StrideD0s,
|
||||
index_t StrideB1,
|
||||
std::array<index_t, NumD1Tensor> StrideD1s,
|
||||
index_t StrideE1,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB0,
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1,
|
||||
AElementwiseOperation a_element_op_,
|
||||
B0ElementwiseOperation b0_element_op_,
|
||||
AccElementwiseOperation acc_element_op_,
|
||||
B1ElementwiseOperation b1_element_op_,
|
||||
CDE1ElementwiseOperation cde1_element_op_)
|
||||
: p_a_grid{p_a_grid_},
|
||||
p_b0_grid{p_b0_grid_},
|
||||
p_d0s_grid{},
|
||||
p_b1_grid{p_b1_grid_},
|
||||
p_d1s_grid{},
|
||||
p_c_e1_grid{p_e1_grid_},
|
||||
M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
O{O_},
|
||||
batch_count{Batch},
|
||||
a_element_op{a_element_op_},
|
||||
b0_element_op{b0_element_op_},
|
||||
acc_element_op{acc_element_op_},
|
||||
b1_element_op{b1_element_op_},
|
||||
cde1_element_op{cde1_element_op_},
|
||||
compute_base_ptr_of_batch{BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0s,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1s,
|
||||
BatchStrideE1}
|
||||
{
|
||||
|
||||
a_g_m_k_lengths = arr3{batch_count, M, K};
|
||||
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
|
||||
|
||||
b0_g_n_k_lengths = arr3{batch_count, N, K};
|
||||
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
|
||||
|
||||
b1_g_o_n_lengths = arr3{batch_count, O, N};
|
||||
b1_g_o_n_strides =
|
||||
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
|
||||
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
|
||||
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
|
||||
|
||||
e1_g_m_o_lengths = arr3{batch_count, M, O};
|
||||
e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O]
|
||||
|
||||
a_grid_desc = Common::GridDescriptorCreator::MakeAGridDescriptor(a_g_m_k_lengths,
|
||||
a_g_m_k_strides);
|
||||
b0_grid_desc = Common::GridDescriptorCreator::MakeB0GridDescriptor(b0_g_n_k_lengths,
|
||||
b0_g_n_k_strides);
|
||||
b1_grid_desc = Common::GridDescriptorCreator::MakeB1GridDescriptor(b1_g_o_n_lengths,
|
||||
b1_g_o_n_strides);
|
||||
c_e1_grid_desc_m_n = Common::GridDescriptorCreator::MakeE1GridDescriptor(
|
||||
e1_g_m_o_lengths, e1_g_m_o_strides);
|
||||
c_e1_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_e1_grid_desc_m_n);
|
||||
block_2_etile_map = GridwiseGemm::MakeDefaultBlock2ETileMap(c_e1_grid_desc_m_n, 1, 1);
|
||||
|
||||
if constexpr(IsMultiD)
|
||||
{
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
|
||||
|
||||
// D0s layout [batch_count, M, N]
|
||||
d0s_g_m_n_lengths[i] = arr3{batch_count, M, N};
|
||||
d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1};
|
||||
|
||||
// D0 pointer
|
||||
p_d0s_grid(i) = static_cast<const D0DataType*>(p_d0s_grid_[i]);
|
||||
});
|
||||
// D0 desc
|
||||
d0s_grid_desc = Common::GridDescriptorCreator::MakeD0sGridDescriptor(
|
||||
d0s_g_m_n_lengths, d0s_g_m_n_strides);
|
||||
|
||||
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
|
||||
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
|
||||
|
||||
// D1s layout [batch_count, M, O]
|
||||
d1s_g_m_o_lengths[i] = arr3{batch_count, M, O};
|
||||
d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1};
|
||||
|
||||
// D1 pointer
|
||||
p_d1s_grid(i) = static_cast<const D1DataType*>(p_d1s_grid_[i]);
|
||||
});
|
||||
// D1 desc
|
||||
d1s_grid_desc = Common::GridDescriptorCreator::MakeD1sGridDescriptor(
|
||||
d1s_g_m_o_lengths, d1s_g_m_o_strides);
|
||||
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
d1s_grid_desc);
|
||||
}
|
||||
}
|
||||
|
||||
// Pointers
|
||||
const ADataType* p_a_grid;
|
||||
const B0DataType* p_b0_grid;
|
||||
typename GridwiseGemm::D0sGridPointer p_d0s_grid;
|
||||
const B1DataType* p_b1_grid;
|
||||
typename GridwiseGemm::D1sGridPointer p_d1s_grid;
|
||||
CE1DataType* p_c_e1_grid;
|
||||
|
||||
// Raw Problem Size
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t O;
|
||||
index_t batch_count;
|
||||
|
||||
arr3 a_g_m_k_lengths;
|
||||
arr3 a_g_m_k_strides;
|
||||
arr3 b0_g_n_k_lengths;
|
||||
arr3 b0_g_n_k_strides;
|
||||
std::array<arr3, NumD0Tensor> d0s_g_m_n_lengths;
|
||||
std::array<arr3, NumD0Tensor> d0s_g_m_n_strides;
|
||||
arr3 b1_g_o_n_lengths;
|
||||
arr3 b1_g_o_n_strides;
|
||||
std::array<arr3, NumD1Tensor> d1s_g_m_o_lengths;
|
||||
std::array<arr3, NumD1Tensor> d1s_g_m_o_strides;
|
||||
arr3 e1_g_m_o_lengths;
|
||||
arr3 e1_g_m_o_strides;
|
||||
|
||||
AElementwiseOperation a_element_op;
|
||||
B0ElementwiseOperation b0_element_op;
|
||||
AccElementwiseOperation acc_element_op;
|
||||
B1ElementwiseOperation b1_element_op;
|
||||
CDE1ElementwiseOperation cde1_element_op;
|
||||
|
||||
// Grid descriptors and other mem calculators
|
||||
typename Common::AGridDesc a_grid_desc;
|
||||
typename Common::B0GridDesc b0_grid_desc;
|
||||
std::conditional_t<IsMultiD, typename Common::D0sGridDesc, Tuple<>> d0s_grid_desc;
|
||||
typename Common::B1GridDesc b1_grid_desc;
|
||||
typename Common::D1sGridDesc d1s_grid_desc;
|
||||
std::conditional_t<
|
||||
IsMultiD,
|
||||
typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Tuple<>>
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
std::conditional_t<IsMultiD, typename Common::E1GridDesc, typename Common::CGridDesc_M_N>
|
||||
c_e1_grid_desc_m_n;
|
||||
typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_e1_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map;
|
||||
|
||||
typename Common::ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
|
||||
};
|
||||
|
||||
/// @brief Helper structure responsible for kernel invocation.
|
||||
///
|
||||
/// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU
|
||||
/// kernel function. It usually determines the launched grid size prepares kernel
|
||||
/// arguments as well as perform specific kernel configuration selection based on
|
||||
/// runtime arguments.
|
||||
///
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
|
||||
|
||||
const index_t grid_size = arg.batch_count * M0 * N0;
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
|
||||
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
|
||||
constexpr TailNumber tail_num = decltype(tail_number)::value;
|
||||
const auto kernel = kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp,
|
||||
GridwiseGemm,
|
||||
has_loop,
|
||||
tail_num,
|
||||
IsMultiD>;
|
||||
return launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
|
||||
};
|
||||
|
||||
bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(arg.K);
|
||||
TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.K);
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid pipeline version!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
// check if DsLayout is supported
|
||||
template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
|
||||
static constexpr bool CheckDLayout()
|
||||
{
|
||||
bool valid = true;
|
||||
// iterate over DLayout tuple
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
// if RefLayout and DLayout are same, keep valid true, otherwise false
|
||||
valid = valid && is_same_v<RefLayout, DLayout>;
|
||||
});
|
||||
return valid;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// Print lambda with env check and printf() style formmating.
|
||||
const char* curFunc = __func__;
|
||||
auto print = [&curFunc](const char* format, ...) -> void {
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
#endif
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
std::vfprintf(stdout, format, args);
|
||||
va_end(args);
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
|
||||
}
|
||||
};
|
||||
|
||||
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
print("DeviceOp: Arch err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
|
||||
std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
|
||||
std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
print("DeviceOp: gfx 11 does not support fp8\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
print("DeviceOp: Acc0 Type err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: A layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
|
||||
is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B1 layout must be Column or Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<CE1Layout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: C layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Other padding modes have not been tested and do not get checked individually.
|
||||
if constexpr(GemmSpec != GemmSpecialization::Default &&
|
||||
GemmSpec != GemmSpecialization::MNKOPadding)
|
||||
{
|
||||
print("Padding mode must be default or MNKO\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Per wmma dimensions not equal to 16 are very untested.
|
||||
if constexpr(MPerWmma != 16 || LPerWmma != 16 || DeviceOp::NPerWmma != 16)
|
||||
{
|
||||
print("M, L, N per Wmma must be 16\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(IsMultiD)
|
||||
{
|
||||
if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B0 layout must be Column\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>()))
|
||||
{
|
||||
print("DeviceOp: All D0s layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>()))
|
||||
{
|
||||
print("DeviceOp: All D1s layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
arg.d0s_grid_desc,
|
||||
arg.b1_grid_desc,
|
||||
arg.d1s_grid_desc,
|
||||
arg.c_e1_grid_desc_m_n,
|
||||
arg.block_2_etile_map))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
|
||||
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
|
||||
const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major
|
||||
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
|
||||
const auto cde1_extent_lowest = arg.O;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
|
||||
cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
print("DeviceOp: Data Transfer Vector scalar err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check vector load/store requirement
|
||||
const auto a_stride_lowest =
|
||||
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
|
||||
const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2
|
||||
? arg.b0_g_n_k_strides[2]
|
||||
: arg.b0_g_n_k_strides[1];
|
||||
const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
|
||||
? arg.b1_g_o_n_strides[2]
|
||||
: arg.b1_g_o_n_strides[1];
|
||||
const auto e1_stride_lowest = arg.e1_g_m_o_strides[2];
|
||||
|
||||
// NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major
|
||||
// and the lowest dimension stride is hardcoded to 1
|
||||
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
|
||||
e1_stride_lowest == 1))
|
||||
{
|
||||
print("DeviceOp: Data Vectorize transfer err\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
Tuple<>{},
|
||||
arg.b1_grid_desc,
|
||||
Tuple<>{},
|
||||
arg.c_e1_grid_desc_m_n,
|
||||
arg.block_2_etile_map))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
|
||||
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
|
||||
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
|
||||
const auto c_extent_lowest = arg.O;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
print("DeviceOp: Data Transfer Vector scalar err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check vector load/store requirement
|
||||
const auto a_stride_lowest =
|
||||
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
|
||||
const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2
|
||||
? arg.b0_g_n_k_strides[2]
|
||||
: arg.b0_g_n_k_strides[1];
|
||||
const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
|
||||
? arg.b1_g_o_n_strides[2]
|
||||
: arg.b1_g_o_n_strides[1];
|
||||
const auto c_stride_lowest = arg.e1_g_m_o_strides[2];
|
||||
|
||||
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
|
||||
c_stride_lowest == 1))
|
||||
{
|
||||
print("DeviceOp: Data Vectorize transfer err\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -3,91 +3,20 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename DeviceOp, typename GridwiseOp, bool HasMainKBlockLoop, TailNumber TailNum>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
|
||||
const long_index_t b0_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
|
||||
const long_index_t b1_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
|
||||
const long_index_t e1_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetE1BasePtr(g_idx)));
|
||||
|
||||
auto p_d0s_grid = GridwiseOp::MakeD0sGridPointer();
|
||||
auto p_d1s_grid = GridwiseOp::MakeD1sGridPointer();
|
||||
|
||||
static_for<0, DeviceOp::NumD0Tensor, 1>{}([&](auto In) {
|
||||
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
|
||||
p_d0s_grid(In) = arg.p_d0s_grid(In) + d0_batch_offset;
|
||||
});
|
||||
|
||||
static_for<0, DeviceOp::NumD1Tensor, 1>{}([&](auto In) {
|
||||
const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In)));
|
||||
p_d1s_grid(In) = arg.p_d1s_grid(In) + d1_batch_offset;
|
||||
});
|
||||
|
||||
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
|
||||
arg.p_a_grid + a_batch_offset,
|
||||
arg.p_b0_grid + b0_batch_offset,
|
||||
p_d0s_grid,
|
||||
arg.p_b1_grid + b1_batch_offset,
|
||||
p_d1s_grid,
|
||||
arg.p_e1_grid + e1_batch_offset,
|
||||
p_shared,
|
||||
arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
arg.d0s_grid_desc,
|
||||
arg.b1_grid_desc,
|
||||
arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.e1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op,
|
||||
arg.b0_element_op,
|
||||
arg.acc_element_op,
|
||||
arg.b1_element_op,
|
||||
arg.cde1_element_op,
|
||||
arg.block_2_etile_map);
|
||||
#else
|
||||
ignore = arg;
|
||||
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
}
|
||||
|
||||
// Computes:
|
||||
// Acc = Acc_Op(A_Op(A) * B0_Op(B0), D0_0, D0_1, ...)
|
||||
// E = CDE1_Op(Acc_Op(Acc0) * B1_Op(B1), D1_0, D1_1, ...)
|
||||
@@ -184,151 +113,51 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
|
||||
static constexpr index_t NumD0Tensor = D0sDataType::Size();
|
||||
static constexpr index_t NumD1Tensor = D1sDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
// To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal
|
||||
// to LPerWmma (A.k.a Gemm0 NPerWmma).
|
||||
static constexpr index_t NPerWmma = LPerWmma;
|
||||
|
||||
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
|
||||
// Transform operator or just not use one at all.
|
||||
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
|
||||
Sequence<1, 1, 1, 1, 1>,
|
||||
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
|
||||
GemmSpec,
|
||||
TensorSpecialization::Default, // ASpec
|
||||
TensorSpecialization::Default, // B0Spec
|
||||
TensorSpecialization::Default, // B1Spec
|
||||
TensorSpecialization::Default>; // CSpec
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
|
||||
const std::array<index_t, 3>& a_g_m_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeAGridDescriptor_AK0_M_AK1(
|
||||
Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
|
||||
Number<AK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
|
||||
const std::array<index_t, 3>& b0_g_l_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
|
||||
Number<BK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
|
||||
const std::array<index_t, 3>& b1_g_n_l_strides_vec)
|
||||
{
|
||||
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
|
||||
Number<L1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeD0GridDescriptor(const std::array<index_t, 3>& d0_g_m_n_lengths_vec,
|
||||
const std::array<index_t, 3>& d0_g_m_n_strides_vec)
|
||||
{
|
||||
return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec);
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeD0sGridDescriptor(
|
||||
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_lengths_vec,
|
||||
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_strides_vec)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]);
|
||||
},
|
||||
Number<NumD0Tensor>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeD1sGridDescriptor(
|
||||
const std::array<std::array<index_t, 3>, NumD0Tensor>& d1_g_m_o_lengths_vec,
|
||||
const std::array<std::array<index_t, 3>, NumD0Tensor>& d1_g_m_o_strides_vec)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]);
|
||||
},
|
||||
Number<NumD1Tensor>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeE1GridDescriptor(const std::array<index_t, 3>& e1_g_m_n_lengths_vec,
|
||||
const std::array<index_t, 3>& e1_g_m_n_strides_vec)
|
||||
{
|
||||
return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec);
|
||||
}
|
||||
|
||||
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
|
||||
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
|
||||
using D0sGridDesc = remove_cvref_t<decltype(MakeD0sGridDescriptor({}, {}))>;
|
||||
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
|
||||
using D1sGridDesc = remove_cvref_t<decltype(MakeD1sGridDescriptor({}, {}))>;
|
||||
using E1GridDesc = decltype(MakeE1GridDescriptor({}, {}));
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
|
||||
index_t BatchStrideB0,
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1)
|
||||
: BatchStrideA0_(BatchStrideA0),
|
||||
BatchStrideB0_(BatchStrideB0),
|
||||
BatchStrideD0s_(BatchStrideD0s),
|
||||
BatchStrideB1_(BatchStrideB1),
|
||||
BatchStrideD1s_(BatchStrideD1s),
|
||||
BatchStrideE1_(BatchStrideE1)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA0_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
|
||||
Number<I> d1_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d1_idx]);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetE1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE1_);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number<I> d1_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA0_;
|
||||
index_t BatchStrideB0_;
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s_;
|
||||
index_t BatchStrideB1_;
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s_;
|
||||
index_t BatchStrideE1_;
|
||||
};
|
||||
using DeviceGemmGemmCommonBase =
|
||||
DeviceGemmGemm_Wmma_CShuffleV3_Common<DeviceOp,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
B0layout,
|
||||
D0sLayout,
|
||||
B1Layout,
|
||||
D1sLayout,
|
||||
E1Layout,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
AccDataType,
|
||||
E1DataType,
|
||||
D0sDataType,
|
||||
D1sDataType,
|
||||
AElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CDE1ElementwiseOperation,
|
||||
AK1,
|
||||
BK1,
|
||||
L1,
|
||||
MPerWmma,
|
||||
LPerWmma,
|
||||
BlkGemmPipelineVer,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
B0BlockTransferSrcVectorDim,
|
||||
B0BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
CDE0BlockTransferSrcScalarPerVector,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
true>; // IsMultiD
|
||||
|
||||
// GridwiseOp
|
||||
using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3<
|
||||
@@ -350,12 +179,12 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
|
||||
CDE1ElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
// InMemory Data Descriptor
|
||||
AGridDesc,
|
||||
B0GridDesc,
|
||||
D0sGridDesc,
|
||||
B1GridDesc,
|
||||
D1sGridDesc,
|
||||
E1GridDesc,
|
||||
typename DeviceGemmGemmCommonBase::AGridDesc,
|
||||
typename DeviceGemmGemmCommonBase::B0GridDesc,
|
||||
typename DeviceGemmGemmCommonBase::D0sGridDesc,
|
||||
typename DeviceGemmGemmCommonBase::B1GridDesc,
|
||||
typename DeviceGemmGemmCommonBase::D1sGridDesc,
|
||||
typename DeviceGemmGemmCommonBase::E1GridDesc,
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
@@ -402,430 +231,67 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
Transform::matrix_padder.PadN,
|
||||
DeviceGemmGemmCommonBase::GridDescriptorCreator::Transform::matrix_padder.PadN,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
|
||||
struct RawArg : public BaseArgument
|
||||
using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg<
|
||||
DeviceOp,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
B0layout,
|
||||
D0sLayout,
|
||||
B1Layout,
|
||||
D1sLayout,
|
||||
E1Layout,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
AccDataType,
|
||||
E1DataType,
|
||||
D0sDataType,
|
||||
D1sDataType,
|
||||
AElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CDE1ElementwiseOperation,
|
||||
AK1,
|
||||
BK1,
|
||||
L1,
|
||||
MPerWmma,
|
||||
LPerWmma,
|
||||
BlkGemmPipelineVer,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
B0BlockTransferSrcVectorDim,
|
||||
B0BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
CDE0BlockTransferSrcScalarPerVector,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
true>; // IsMultiD
|
||||
// Invoker
|
||||
using Invoker = typename DeviceGemmGemmCommon::Invoker;
|
||||
|
||||
// Argument
|
||||
using Argument = typename DeviceGemmGemmCommon::Argument;
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
using arr3 = std::array<ck::index_t, 3>;
|
||||
|
||||
RawArg(const ADataType* p_a_grid_,
|
||||
const B0DataType* p_b0_grid_,
|
||||
std::array<const void*, NumD0Tensor> p_d0s_grid_,
|
||||
const B1DataType* p_b1_grid_,
|
||||
std::array<const void*, NumD1Tensor> p_d1s_grid_,
|
||||
E1DataType* p_e1_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t O_,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB0,
|
||||
std::array<index_t, NumD0Tensor> StrideD0s,
|
||||
index_t StrideB1,
|
||||
std::array<index_t, NumD1Tensor> StrideD1s,
|
||||
index_t StrideE1,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB0,
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1,
|
||||
AElementwiseOperation a_element_op_,
|
||||
B0ElementwiseOperation b0_element_op_,
|
||||
AccElementwiseOperation acc_element_op_,
|
||||
B1ElementwiseOperation b1_element_op_,
|
||||
CDE1ElementwiseOperation cde1_element_op_)
|
||||
: p_a_grid{p_a_grid_},
|
||||
p_b0_grid{p_b0_grid_},
|
||||
p_d0s_grid{},
|
||||
p_b1_grid{p_b1_grid_},
|
||||
p_d1s_grid{},
|
||||
p_e1_grid{p_e1_grid_},
|
||||
M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
O{O_},
|
||||
batch_count{Batch},
|
||||
a_element_op{a_element_op_},
|
||||
b0_element_op{b0_element_op_},
|
||||
acc_element_op{acc_element_op_},
|
||||
b1_element_op{b1_element_op_},
|
||||
cde1_element_op{cde1_element_op_},
|
||||
compute_base_ptr_of_batch{BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0s,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1s,
|
||||
BatchStrideE1}
|
||||
{
|
||||
|
||||
a_g_m_k_lengths = arr3{batch_count, M, K};
|
||||
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
|
||||
|
||||
b0_g_n_k_lengths = arr3{batch_count, N, K};
|
||||
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
|
||||
|
||||
b1_g_o_n_lengths = arr3{batch_count, O, N};
|
||||
b1_g_o_n_strides =
|
||||
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
|
||||
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
|
||||
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
|
||||
|
||||
e1_g_m_o_lengths = arr3{batch_count, M, O};
|
||||
e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O]
|
||||
|
||||
a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides);
|
||||
b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides);
|
||||
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
|
||||
e1_grid_desc_m_n = MakeE1GridDescriptor(e1_g_m_o_lengths, e1_g_m_o_strides);
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e1_grid_desc_m_n);
|
||||
block_2_etile_map = GridwiseOp::MakeDefaultBlock2ETileMap(e1_grid_desc_m_n, 1, 1);
|
||||
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
|
||||
|
||||
// D0s layout [batch_count, M, N]
|
||||
d0s_g_m_n_lengths[i] = arr3{batch_count, M, N};
|
||||
d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1};
|
||||
|
||||
// D0 pointer
|
||||
p_d0s_grid(i) = static_cast<const D0DataType*>(p_d0s_grid_[i]);
|
||||
|
||||
// D0 desc
|
||||
d0s_grid_desc(i) = MakeD0GridDescriptor(d0s_g_m_n_lengths[i], d0s_g_m_n_strides[i]);
|
||||
});
|
||||
|
||||
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
|
||||
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
|
||||
|
||||
// D1s layout [batch_count, M, O]
|
||||
d1s_g_m_o_lengths[i] = arr3{batch_count, M, O};
|
||||
d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1};
|
||||
|
||||
// D1 pointer
|
||||
p_d1s_grid(i) = static_cast<const D1DataType*>(p_d1s_grid_[i]);
|
||||
|
||||
// D1 desc
|
||||
d1s_grid_desc(i) = MakeE1GridDescriptor(d1s_g_m_o_lengths[i], d1s_g_m_o_strides[i]);
|
||||
});
|
||||
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseOp::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(d1s_grid_desc);
|
||||
}
|
||||
|
||||
// Pointers
|
||||
const ADataType* p_a_grid;
|
||||
const B0DataType* p_b0_grid;
|
||||
typename GridwiseOp::D0sGridPointer p_d0s_grid;
|
||||
const B1DataType* p_b1_grid;
|
||||
typename GridwiseOp::D1sGridPointer p_d1s_grid;
|
||||
E1DataType* p_e1_grid;
|
||||
|
||||
// Raw Problem Size
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t O;
|
||||
index_t batch_count;
|
||||
|
||||
arr3 a_g_m_k_lengths;
|
||||
arr3 a_g_m_k_strides;
|
||||
arr3 b0_g_n_k_lengths;
|
||||
arr3 b0_g_n_k_strides;
|
||||
std::array<arr3, NumD0Tensor> d0s_g_m_n_lengths;
|
||||
std::array<arr3, NumD0Tensor> d0s_g_m_n_strides;
|
||||
arr3 b1_g_o_n_lengths;
|
||||
arr3 b1_g_o_n_strides;
|
||||
std::array<arr3, NumD1Tensor> d1s_g_m_o_lengths;
|
||||
std::array<arr3, NumD1Tensor> d1s_g_m_o_strides;
|
||||
arr3 e1_g_m_o_lengths;
|
||||
arr3 e1_g_m_o_strides;
|
||||
|
||||
AElementwiseOperation a_element_op;
|
||||
B0ElementwiseOperation b0_element_op;
|
||||
AccElementwiseOperation acc_element_op;
|
||||
B1ElementwiseOperation b1_element_op;
|
||||
CDE1ElementwiseOperation cde1_element_op;
|
||||
|
||||
// Grid descriptors and other mem calculators
|
||||
AGridDesc a_grid_desc;
|
||||
B0GridDesc b0_grid_desc;
|
||||
D0sGridDesc d0s_grid_desc;
|
||||
B1GridDesc b1_grid_desc;
|
||||
D1sGridDesc d1s_grid_desc;
|
||||
typename GridwiseOp::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
E1GridDesc e1_grid_desc_m_n;
|
||||
typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
typename GridwiseOp::DefaultBlock2ETileMap block_2_etile_map;
|
||||
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
|
||||
};
|
||||
|
||||
// check if DsLayout is supported
|
||||
template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
|
||||
static constexpr bool CheckDLayout()
|
||||
{
|
||||
bool valid = true;
|
||||
// iterate over DLayout tuple
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
// if RefLayout and DLayout are same, keep valid true, otherwise false
|
||||
valid = valid && is_same_v<RefLayout, DLayout>;
|
||||
});
|
||||
return valid;
|
||||
return DeviceGemmGemmCommon::IsSupportedArgument(arg);
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
|
||||
{
|
||||
// Print lambda with env check and printf() style formmating.
|
||||
const char* curFunc = __func__;
|
||||
auto print = [&curFunc](const char* format, ...) -> void {
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
#endif
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
std::vfprintf(stdout, format, args);
|
||||
va_end(args);
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
|
||||
}
|
||||
};
|
||||
|
||||
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
print("DeviceOp: Arch err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
|
||||
std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
|
||||
std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
print("DeviceOp: gfx 11 does not support fp8\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
print("DeviceOp: Acc0 Type err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: A layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B0 layout must be Column\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>()))
|
||||
{
|
||||
print("DeviceOp: All D0s layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
|
||||
is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B1 layout must be Column or Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>()))
|
||||
{
|
||||
print("DeviceOp: All D1s layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<E1Layout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: C layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Other padding modes have not been tested and do not get checked individually.
|
||||
if constexpr(GemmSpec != GemmSpecialization::Default &&
|
||||
GemmSpec != GemmSpecialization::MNKOPadding)
|
||||
{
|
||||
print("Padding mode must be default or MNKO\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Per wmma dimensions not equal to 16 are very untested.
|
||||
if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
|
||||
{
|
||||
print("M, L, N per Wmma must be 16\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
arg.d0s_grid_desc,
|
||||
arg.b1_grid_desc,
|
||||
arg.d1s_grid_desc,
|
||||
arg.e1_grid_desc_m_n,
|
||||
arg.block_2_etile_map))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
|
||||
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
|
||||
const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major
|
||||
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
|
||||
const auto cde1_extent_lowest = arg.O;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
|
||||
cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
print("DeviceOp: Data Transfer Vector scalar err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check vector load/store requirement
|
||||
const auto a_stride_lowest =
|
||||
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
|
||||
const auto b0_stride_lowest =
|
||||
B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
|
||||
const auto b1_stride_lowest =
|
||||
B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
|
||||
const auto e1_stride_lowest = arg.e1_g_m_o_strides[2];
|
||||
|
||||
// NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major
|
||||
// and the lowest dimension stride is hardcoded to 1
|
||||
|
||||
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
|
||||
e1_stride_lowest == 1))
|
||||
{
|
||||
print("DeviceOp: Data Vectorize transfer err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
|
||||
return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::RawArg;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
|
||||
|
||||
const index_t grid_size = arg.batch_count * M0 * N0;
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
|
||||
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
|
||||
constexpr TailNumber tn = tail_number;
|
||||
|
||||
const auto kernel =
|
||||
kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3<DeviceOp,
|
||||
GridwiseOp,
|
||||
has_loop,
|
||||
tn>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
|
||||
};
|
||||
|
||||
bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
|
||||
TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K);
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid pipeline version!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a0,
|
||||
const B0DataType* p_b0,
|
||||
std::array<const void*, NumD0Tensor> p_d0s,
|
||||
@@ -855,20 +321,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CDE1ElementwiseOperation cde1_element_op)
|
||||
{
|
||||
return RawArg{p_a0, p_b0,
|
||||
p_d0s, p_b1,
|
||||
p_d1s, p_e1,
|
||||
MRaw, NRaw,
|
||||
KRaw, Gemm1NRaw,
|
||||
Batch, StrideA0,
|
||||
StrideB0, StrideD0s,
|
||||
StrideB1, StrideD1s,
|
||||
StrideE1, BatchStrideA0,
|
||||
BatchStrideB0, BatchStrideD0s,
|
||||
BatchStrideB1, BatchStrideD1s,
|
||||
BatchStrideE1, a0_element_op,
|
||||
b0_element_op, cde0_element_op,
|
||||
b1_element_op, cde1_element_op};
|
||||
return Argument{p_a0, p_b0,
|
||||
p_d0s, p_b1,
|
||||
p_d1s, p_e1,
|
||||
MRaw, NRaw,
|
||||
KRaw, Gemm1NRaw,
|
||||
Batch, StrideA0,
|
||||
StrideB0, StrideD0s,
|
||||
StrideB1, StrideD1s,
|
||||
StrideE1, BatchStrideA0,
|
||||
BatchStrideB0, BatchStrideD0s,
|
||||
BatchStrideB1, BatchStrideD1s,
|
||||
BatchStrideE1, a0_element_op,
|
||||
b0_element_op, cde0_element_op,
|
||||
b1_element_op, cde1_element_op};
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -902,34 +368,34 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CDE1ElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<RawArg>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const B0DataType*>(p_b0),
|
||||
p_d0s,
|
||||
static_cast<const B1DataType*>(p_b1),
|
||||
p_d1s,
|
||||
static_cast<E1DataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
Batch,
|
||||
StrideA,
|
||||
StrideB0,
|
||||
StrideD0s,
|
||||
StrideB1,
|
||||
StrideD1s,
|
||||
StrideE1,
|
||||
BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0s,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1s,
|
||||
BatchStrideE1,
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const B0DataType*>(p_b0),
|
||||
p_d0s,
|
||||
static_cast<const B1DataType*>(p_b1),
|
||||
p_d1s,
|
||||
static_cast<E1DataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
Batch,
|
||||
StrideA,
|
||||
StrideB0,
|
||||
StrideD0s,
|
||||
StrideB1,
|
||||
StrideD1s,
|
||||
StrideE1,
|
||||
BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0s,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1s,
|
||||
BatchStrideE1,
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
@@ -606,6 +606,26 @@ struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
|
||||
@@ -588,6 +588,28 @@ struct DeviceBatchedGemmReduce_Wmma_CShuffleV3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{arg.p_a_grid_},
|
||||
std::array<const void*, 1>{arg.p_b_grid_},
|
||||
std::array<const void*, 0>{},
|
||||
|
||||
@@ -455,6 +455,26 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_Common
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
|
||||
@@ -471,6 +471,28 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
typename GridwiseGemm::Argument gemm_arg{
|
||||
std::array<const void*, 1>{arg.p_a_grid_},
|
||||
std::array<const void*, 1>{arg.p_b_grid_},
|
||||
|
||||
@@ -701,6 +701,28 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemmWelford::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemmWelford::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
typename GridwiseGemmWelford::Argument gemm_arg{
|
||||
std::array<const void*, 1>{arg.p_a_grid_},
|
||||
std::array<const void*, 1>{arg.p_b_grid_},
|
||||
|
||||
@@ -456,6 +456,28 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() &&
|
||||
!GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{arg.p_a_grid_},
|
||||
std::array<const void*, 1>{arg.p_b_grid_},
|
||||
std::array<const void*, 0>{},
|
||||
|
||||
@@ -421,6 +421,26 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
}
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -393,6 +393,26 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(
|
||||
*dynamic_cast<const typename GridwiseGemm::Argument*>(&arg));
|
||||
}
|
||||
|
||||
@@ -450,8 +450,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
|
||||
BlkGemmPipelineVer,
|
||||
AComputeType,
|
||||
BComputeType,
|
||||
false,
|
||||
false>;
|
||||
false, // PermuteA
|
||||
false, // PermuteB
|
||||
false, // IsBPreShuffled
|
||||
true>; // ForceThreadTileTransfer
|
||||
|
||||
#define GridwiseGemmCTransposeTemplateParameters \
|
||||
ALayout, BLayout, DsLayout, ELayout, Tuple<ADataType>, Tuple<BDataType>, AccDataType, \
|
||||
@@ -467,7 +469,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
|
||||
ABlockLdsExtraM, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, \
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CShuffleBlockTransferScalarPerVector, BlkGemmPipeSched, BlkGemmPipelineVer, BComputeType, \
|
||||
AComputeType, false, false
|
||||
AComputeType, false, false, false, true
|
||||
|
||||
using GridwiseGemmCTranspose =
|
||||
std::conditional_t<CTranspose,
|
||||
|
||||
@@ -162,7 +162,6 @@ struct DeviceGroupedConvBwdWeight_Explicit
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(split_k < 0)
|
||||
{
|
||||
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
|
||||
@@ -171,9 +170,11 @@ struct DeviceGroupedConvBwdWeight_Explicit
|
||||
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
|
||||
const index_t grid_size = gdx * gdy * gdz;
|
||||
k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
|
||||
|
||||
// Cap k_batch_ to 128 to avoid accuracy issues
|
||||
k_batch_ = std::min(k_batch_, 128);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
@@ -338,16 +339,6 @@ struct DeviceGroupedConvBwdWeight_Explicit
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if constexpr(!IsTwoStageNeeded)
|
||||
{
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
@@ -524,6 +525,44 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}, 1, 1));
|
||||
|
||||
struct ActiveWorkgroupsPerCU
|
||||
{
|
||||
ActiveWorkgroupsPerCU()
|
||||
{
|
||||
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
|
||||
{
|
||||
return;
|
||||
}
|
||||
constexpr int dynamic_smem_size = 0;
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
int max_occupancy = 0;
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
// TODO: implement
|
||||
}
|
||||
else
|
||||
{
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>,
|
||||
BlockSize,
|
||||
dynamic_smem_size));
|
||||
}
|
||||
max_occupancy_ = std::max(1, max_occupancy);
|
||||
}
|
||||
int max_occupancy_;
|
||||
};
|
||||
|
||||
struct Argument : public BaseArgument, public ArgumentSplitK
|
||||
{
|
||||
Argument(
|
||||
@@ -574,6 +613,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
|
||||
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(b_g_n_c_wis_lengths),
|
||||
@@ -585,7 +626,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN, gemmK;
|
||||
@@ -602,6 +642,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
|
||||
// Cap k_batch_ to 128 to avoid accuracy issues
|
||||
k_batch_ = std::min(k_batch_, 128);
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
|
||||
@@ -611,7 +654,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
@@ -988,13 +1030,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
|
||||
|
||||
@@ -677,7 +677,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN;
|
||||
@@ -688,9 +687,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
|
||||
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
|
||||
grid_size);
|
||||
|
||||
// Cap k_batch_ to 128 to avoid accuracy issues
|
||||
k_batch_ = std::min(k_batch_, 128);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
@@ -947,12 +948,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
|
||||
{
|
||||
return false;
|
||||
|
||||
@@ -511,7 +511,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset,
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN, gemmK;
|
||||
@@ -528,6 +528,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
|
||||
k_batch_ = std::min(k_batch_, k_batch_max);
|
||||
|
||||
// Cap k_batch_ to 128 to avoid accuracy issues
|
||||
k_batch_ = std::min(k_batch_, 128);
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
|
||||
@@ -537,7 +540,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
@@ -1040,12 +1042,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
|
||||
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
|
||||
|
||||
@@ -651,7 +651,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
|
||||
e_g_k_c_xs_strides);
|
||||
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN;
|
||||
@@ -662,9 +661,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
|
||||
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
|
||||
grid_size);
|
||||
|
||||
// Cap k_batch_ to 128 to avoid accuracy issues
|
||||
k_batch_ = std::min(k_batch_, 128);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
@@ -1083,12 +1084,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
|
||||
{
|
||||
return false;
|
||||
|
||||
@@ -568,7 +568,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(split_k < 0)
|
||||
{
|
||||
ck::index_t gemmM, gemmN, gemmK;
|
||||
@@ -585,6 +584,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
|
||||
k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
|
||||
|
||||
// Cap k_batch_ to 128 to avoid accuracy issues
|
||||
k_batch_ = std::min(k_batch_, 128);
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
|
||||
@@ -594,7 +596,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
k_batch_ = split_k;
|
||||
}
|
||||
@@ -1373,12 +1374,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
||||
if(arg.k_batch_ < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
// check device
|
||||
if constexpr(DirectLoad)
|
||||
|
||||
@@ -503,6 +503,29 @@ struct DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
|
||||
bool supported = true;
|
||||
for(index_t i = 0; i < arg.group_count_; ++i)
|
||||
{
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(
|
||||
arg.gemm_descs_[i].M_, arg.gemm_descs_[i].K_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(
|
||||
arg.gemm_descs_[i].N_, arg.gemm_descs_[i].K_))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
|
||||
std::array<index_t, NumDTensor> stride_Ds;
|
||||
std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());
|
||||
|
||||
@@ -704,7 +704,28 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
bool supported = true;
|
||||
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
const auto& a = arg.gemm_kernel_args_[i].karg_;
|
||||
const auto& a = arg.gemm_kernel_args_[i].karg_;
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(a.M, a.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(a.N, a.K))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
|
||||
|
||||
if(not group_arg_valid)
|
||||
|
||||
@@ -1631,6 +1631,13 @@ struct ConvInvscale
|
||||
e = type_convert<f8_t>(c / scale_in_ / scale_wei_ / scale_out_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& e, const f8_t& c) const
|
||||
{
|
||||
const float c_float = type_convert<float>(c);
|
||||
e = type_convert<f8_t>(c_float / scale_in_ / scale_wei_ / scale_out_);
|
||||
};
|
||||
|
||||
float scale_in_;
|
||||
float scale_wei_;
|
||||
float scale_out_;
|
||||
@@ -1656,6 +1663,13 @@ struct ConvScale
|
||||
e = type_convert<f8_t>(c * scale_in_ * scale_wei_ * scale_out_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& e, const f8_t& c) const
|
||||
{
|
||||
const float c_float = type_convert<float>(c);
|
||||
e = type_convert<f8_t>(c_float * scale_in_ * scale_wei_ * scale_out_);
|
||||
};
|
||||
|
||||
float scale_in_;
|
||||
float scale_wei_;
|
||||
float scale_out_;
|
||||
@@ -1683,6 +1697,15 @@ struct ConvScaleRelu
|
||||
e = type_convert<f8_t>(x * scale_out_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& e, const f8_t& c) const
|
||||
{
|
||||
const float c_float = type_convert<float>(c);
|
||||
float x;
|
||||
Relu{}.template operator()<float>(x, c_float * scale_in_ * scale_wei_);
|
||||
e = type_convert<f8_t>(x * scale_out_);
|
||||
};
|
||||
|
||||
float scale_in_;
|
||||
float scale_wei_;
|
||||
float scale_out_;
|
||||
|
||||
@@ -132,10 +132,6 @@ struct ABTransferWaveTiles
|
||||
index_t,
|
||||
index_t)
|
||||
{
|
||||
// Notes: padding is currently not supported with transpose
|
||||
static_assert(!((PadMN || PadK) && ABDoTranspose),
|
||||
"padding is currently not supported with transpose");
|
||||
|
||||
const index_t MN_grid = !PadMN ? sizeMN : MNPad;
|
||||
const index_t K_grid = !PadK ? sizeK : KPad;
|
||||
|
||||
|
||||
@@ -362,23 +362,27 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
|
||||
.wave_size;
|
||||
|
||||
__host__ __device__ static constexpr bool AWaveTransferApplicable()
|
||||
{
|
||||
return !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 &&
|
||||
ABlockTransferSrcScalarPerVector == 8 && ABlockTransferDstScalarPerVector_AK1 == 8 &&
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 &&
|
||||
!IsBPreShuffled;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool BWaveTransferApplicable()
|
||||
{
|
||||
return !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 &&
|
||||
BBlockTransferSrcScalarPerVector == 8 && BBlockTransferDstScalarPerVector_BK1 == 8 &&
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8;
|
||||
}
|
||||
|
||||
// Limitations of the current implementation:
|
||||
// - no multiAB
|
||||
// - GemmSpecialization Default with transpose
|
||||
#ifdef __gfx12__
|
||||
static constexpr bool IsAWaveTransferApplicable =
|
||||
!ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 &&
|
||||
((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
|
||||
!is_same_v<ALayout, tensor_layout::gemm::RowMajor>) ||
|
||||
is_same_v<ALayout, tensor_layout::gemm::RowMajor>) &&
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled;
|
||||
static constexpr bool IsAWaveTransferApplicable = AWaveTransferApplicable();
|
||||
|
||||
static constexpr bool IsBWaveTransferApplicable =
|
||||
!ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 &&
|
||||
((GemmSpec == tensor_operation::device::GemmSpecialization::Default &&
|
||||
!is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) ||
|
||||
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>) &&
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8;
|
||||
static constexpr bool IsBWaveTransferApplicable = BWaveTransferApplicable();
|
||||
|
||||
static constexpr bool IsWaveTileInterleavedFitting =
|
||||
(NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize);
|
||||
@@ -982,6 +986,55 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
return de_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
}
|
||||
|
||||
// Conditions for Wave Transfer with transpose:
|
||||
// - 16 bit type: K % 8 == 0 (4 subtiles of 8x8)
|
||||
// - 8 bit type: K % 8 == 0 and M % 16 == 0 (2 subtiles of 8x16)
|
||||
__host__ static constexpr bool CheckValidityAWaveTransfer(const index_t& M, const index_t& K)
|
||||
{
|
||||
if constexpr(AWaveTransferApplicable() &&
|
||||
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
|
||||
{
|
||||
if(!(K % ABlockTransferDstScalarPerVector_AK1 == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
bool pass = true;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
pass &= !(sizeof(ADataType_) == 1 &&
|
||||
!(M % (2 * ABlockTransferSrcScalarPerVector) == 0));
|
||||
});
|
||||
return pass;
|
||||
}
|
||||
else
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
__host__ static constexpr bool CheckValidityBWaveTransfer(const index_t& N, const index_t& K)
|
||||
{
|
||||
if constexpr(BWaveTransferApplicable() &&
|
||||
!(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value))
|
||||
{
|
||||
if(!(K % BBlockTransferDstScalarPerVector_BK1 == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
bool pass = true;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
pass &= !(sizeof(BDataType_) == 1 &&
|
||||
!(N % (2 * BBlockTransferSrcScalarPerVector) == 0));
|
||||
});
|
||||
return pass;
|
||||
}
|
||||
else
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
template <typename Argument>
|
||||
__host__ static constexpr bool CheckValidity(const Argument& karg,
|
||||
|
||||
@@ -199,55 +199,113 @@ template <index_t N>
|
||||
using make_index_sequence =
|
||||
typename __make_integer_seq<impl::__integer_sequence, index_t, N>::seq_type;
|
||||
|
||||
// merge sequence
|
||||
template <typename Seq, typename... Seqs>
|
||||
struct sequence_merge
|
||||
// merge sequence - optimized to avoid recursive instantiation
|
||||
//
|
||||
// Note: Unlike sequence_gen and uniform_sequence_gen which use __make_integer_seq for O(1)
|
||||
// instantiation depth, sequence_merge cannot achieve O(1) depth. Here's why:
|
||||
//
|
||||
// - sequence_gen and uniform_sequence_gen generate a SINGLE output sequence where each
|
||||
// element can be computed independently: output[i] = f(i)
|
||||
//
|
||||
// - sequence_merge takes MULTIPLE input sequences with different, unknown lengths.
|
||||
// To compute output[i], we need to know:
|
||||
// 1. Which input sequence contains this index
|
||||
// 2. The offset within that sequence
|
||||
// This requires computing cumulative sequence lengths, which requires recursion/iteration.
|
||||
//
|
||||
// Instead, we use a binary tree reduction approach that achieves O(log N) instantiation depth:
|
||||
// - Base cases handle 1-4 sequences directly (O(1) for common cases)
|
||||
// - Recursive case merges pairs then combines: merge(s1,s2) + merge(s3,s4,...)
|
||||
// - This gives O(log N) depth, which is optimal for merging heterogeneous sequences
|
||||
//
|
||||
// Alternative considered: Fold expressions (... + sequences) would give O(N) depth due to
|
||||
// linear dependency chain, so binary tree is superior.
|
||||
//
|
||||
namespace detail {
|
||||
|
||||
// Helper to concatenate multiple sequences in one step using fold expression
|
||||
template <typename... Seqs>
|
||||
struct sequence_merge_impl;
|
||||
|
||||
// Base case: single sequence
|
||||
template <index_t... Is>
|
||||
struct sequence_merge_impl<Sequence<Is...>>
|
||||
{
|
||||
using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
|
||||
using type = Sequence<Is...>;
|
||||
};
|
||||
|
||||
// Two sequences: direct concatenation
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
|
||||
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>>
|
||||
{
|
||||
using type = Sequence<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
template <typename Seq>
|
||||
struct sequence_merge<Seq>
|
||||
// Three sequences: direct concatenation (avoids one level of recursion)
|
||||
template <index_t... Xs, index_t... Ys, index_t... Zs>
|
||||
struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>>
|
||||
{
|
||||
using type = Seq;
|
||||
using type = Sequence<Xs..., Ys..., Zs...>;
|
||||
};
|
||||
|
||||
// generate sequence
|
||||
// Four sequences: direct concatenation
|
||||
template <index_t... As, index_t... Bs, index_t... Cs, index_t... Ds>
|
||||
struct sequence_merge_impl<Sequence<As...>, Sequence<Bs...>, Sequence<Cs...>, Sequence<Ds...>>
|
||||
{
|
||||
using type = Sequence<As..., Bs..., Cs..., Ds...>;
|
||||
};
|
||||
|
||||
// General case: binary tree reduction (O(log N) depth instead of O(N))
|
||||
template <typename S1, typename S2, typename S3, typename S4, typename... Rest>
|
||||
struct sequence_merge_impl<S1, S2, S3, S4, Rest...>
|
||||
{
|
||||
// Merge pairs first, then recurse
|
||||
using left = typename sequence_merge_impl<S1, S2>::type;
|
||||
using right = typename sequence_merge_impl<S3, S4, Rest...>::type;
|
||||
using type = typename sequence_merge_impl<left, right>::type;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename... Seqs>
|
||||
struct sequence_merge
|
||||
{
|
||||
using type = typename detail::sequence_merge_impl<Seqs...>::type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct sequence_merge<>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
// generate sequence - optimized using __make_integer_seq to avoid recursive instantiation
|
||||
namespace detail {
|
||||
|
||||
// Helper that applies functor F to indices and produces a Sequence
|
||||
// __make_integer_seq<sequence_gen_helper, index_t, N> produces sequence_gen_helper<index_t, 0, 1,
|
||||
// ..., N-1>
|
||||
template <typename T, T... Is>
|
||||
struct sequence_gen_helper
|
||||
{
|
||||
// Apply a functor F to all indices at once via pack expansion (O(1) depth)
|
||||
template <typename F>
|
||||
using apply = Sequence<F{}(Number<Is>{})...>;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <index_t NSize, typename F>
|
||||
struct sequence_gen
|
||||
{
|
||||
template <index_t IBegin, index_t NRemain, typename G>
|
||||
struct sequence_gen_impl
|
||||
{
|
||||
static constexpr index_t NRemainLeft = NRemain / 2;
|
||||
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
|
||||
static constexpr index_t IMiddle = IBegin + NRemainLeft;
|
||||
using type =
|
||||
typename __make_integer_seq<detail::sequence_gen_helper, index_t, NSize>::template apply<F>;
|
||||
};
|
||||
|
||||
using type = typename sequence_merge<
|
||||
typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
|
||||
typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
|
||||
};
|
||||
|
||||
template <index_t I, typename G>
|
||||
struct sequence_gen_impl<I, 1, G>
|
||||
{
|
||||
static constexpr index_t Is = G{}(Number<I>{});
|
||||
using type = Sequence<Is>;
|
||||
};
|
||||
|
||||
template <index_t I, typename G>
|
||||
struct sequence_gen_impl<I, 0, G>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
using type = typename sequence_gen_impl<0, NSize, F>::type;
|
||||
template <typename F>
|
||||
struct sequence_gen<0, F>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
// arithmetic sequence
|
||||
@@ -283,16 +341,30 @@ struct arithmetic_sequence_gen<0, IEnd, 1>
|
||||
using type = typename __make_integer_seq<WrapSequence, index_t, IEnd>::type;
|
||||
};
|
||||
|
||||
// uniform sequence
|
||||
// uniform sequence - optimized using __make_integer_seq
|
||||
namespace detail {
|
||||
|
||||
template <typename T, T... Is>
|
||||
struct uniform_sequence_helper
|
||||
{
|
||||
// Apply a constant value to all indices via pack expansion
|
||||
template <index_t Value>
|
||||
using apply = Sequence<((void)Is, Value)...>;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <index_t NSize, index_t I>
|
||||
struct uniform_sequence_gen
|
||||
{
|
||||
struct F
|
||||
{
|
||||
__host__ __device__ constexpr index_t operator()(index_t) const { return I; }
|
||||
};
|
||||
using type = typename __make_integer_seq<detail::uniform_sequence_helper, index_t, NSize>::
|
||||
template apply<I>;
|
||||
};
|
||||
|
||||
using type = typename sequence_gen<NSize, F>::type;
|
||||
template <index_t I>
|
||||
struct uniform_sequence_gen<0, I>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
// reverse inclusive scan (with init) sequence
|
||||
|
||||
@@ -20,6 +20,7 @@ struct tuple_concat<Tuple<Xs...>, Tuple<Ys...>>
|
||||
using type = Tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
// StaticallyIndexedArrayImpl uses binary split for O(log N) depth
|
||||
template <typename T, index_t N>
|
||||
struct StaticallyIndexedArrayImpl
|
||||
{
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
/** @brief Maximum number of error values to display when checking errors */
|
||||
constexpr int ERROR_DETAIL_LIMIT = 128;
|
||||
constexpr int ERROR_DETAIL_LIMIT = 16;
|
||||
|
||||
/** @brief 8-bit floating point type */
|
||||
using F8 = ck_tile::fp8_t;
|
||||
|
||||
@@ -227,7 +227,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
sequence<1>>{});
|
||||
else
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
tile_distribution_encoding<
|
||||
sequence<NWarps>,
|
||||
tuple<sequence<MWarps, MXdlPack, MPerXdl>,
|
||||
sequence<K_Thread / AK1, K_Lane, AK1 / APackedSize>>,
|
||||
|
||||
@@ -274,7 +274,9 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
// for every column in AQ
|
||||
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
|
||||
// for every warp corresponding to a quantization scale
|
||||
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
|
||||
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
|
||||
|
||||
@@ -322,6 +324,214 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GemmTraits>
|
||||
struct BlockGemmImpl<GemmPipelineScheduler::Interwave, GemmTraits>
|
||||
{
|
||||
static constexpr index_t KPerThread = GemmTraits::KPerThread;
|
||||
static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
|
||||
|
||||
static constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
static constexpr auto ALdsTileDistr =
|
||||
make_static_tile_distribution(MakeABlockDistributionEncode());
|
||||
static constexpr auto BLdsTileDistr =
|
||||
make_static_tile_distribution(MakeBBlockDistributionEncode());
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
template <index_t KIdx,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
constexpr auto a_lds_load_distr = [&]() {
|
||||
if constexpr(ALoadTranspose)
|
||||
return make_static_tile_distribution(typename InputTileDistributionTraits<
|
||||
decltype(MakeABlockDistributionEncode()),
|
||||
ADataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return make_static_tile_distribution(MakeABlockDistributionEncode());
|
||||
}();
|
||||
constexpr auto b_lds_load_distr = [&]() {
|
||||
if constexpr(BLoadTranspose)
|
||||
return make_static_tile_distribution(typename InputTileDistributionTraits<
|
||||
decltype(MakeBBlockDistributionEncode()),
|
||||
BDataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return make_static_tile_distribution(MakeBBlockDistributionEncode());
|
||||
}();
|
||||
constexpr auto a_lds_shape = []() {
|
||||
if constexpr(ALoadTranspose)
|
||||
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::MPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<GemmTraits::MPerBlock>{}, number<KPerInnerLoop>{});
|
||||
}();
|
||||
constexpr auto b_lds_shape = []() {
|
||||
if constexpr(BLoadTranspose)
|
||||
return make_tuple(number<KPerInnerLoop>{}, number<GemmTraits::NPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<GemmTraits::NPerBlock>{}, number<KPerInnerLoop>{});
|
||||
}();
|
||||
constexpr auto k_idx_offset = KIdx * KPerInnerLoop;
|
||||
constexpr auto a_offset =
|
||||
ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
|
||||
constexpr auto b_offset =
|
||||
BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset};
|
||||
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr);
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
|
||||
|
||||
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
|
||||
a_warp_tile_, a_lds_gemm_window);
|
||||
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
|
||||
b_warp_tile_, b_lds_gemm_window);
|
||||
}
|
||||
|
||||
// C += A * B with quantization support
|
||||
template <typename CBlockTensor,
|
||||
typename AQBlockTensor,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
AQBlockTensor& aq_block_tensor,
|
||||
const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window,
|
||||
bool_constant<ALoadTranspose> a_load_tr = {},
|
||||
bool_constant<BLoadTranspose> b_load_tr = {})
|
||||
{
|
||||
static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
|
||||
"The CDataType as defined in traits should be the same as corresponding "
|
||||
"C block tensor data type!");
|
||||
constexpr auto warp_size = get_warp_size();
|
||||
|
||||
// Track which KRepeat chunk is currently loaded
|
||||
index_t current_k_repeat_loaded = -1;
|
||||
|
||||
// Restructured loop: M → N → QScale → KIterPerQScale
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// Iterate over quantization groups
|
||||
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
// Accumulate K iterations for this quantization group
|
||||
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
|
||||
// Map quantization indices to global K iteration
|
||||
constexpr auto kIterGlobal =
|
||||
kQScale * Traits::KIterPerQScale + kIterInQScale;
|
||||
|
||||
// Map to KRepeat chunk and KInnerLoopIter offset
|
||||
constexpr auto kRepeatIdx = kIterGlobal / KInnerLoopIter;
|
||||
constexpr auto kInnerIdx = kIterGlobal % KInnerLoopIter;
|
||||
|
||||
// Prefetch new chunk if needed
|
||||
if constexpr(kInnerIdx == 0)
|
||||
{
|
||||
if(current_k_repeat_loaded != kRepeatIdx)
|
||||
{
|
||||
LocalPrefetch<kRepeatIdx>(
|
||||
a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
if constexpr(kRepeatIdx != 0 || KRepeat == 1)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
current_k_repeat_loaded = kRepeatIdx;
|
||||
}
|
||||
}
|
||||
|
||||
// Load A warp tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() =
|
||||
a_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kInnerIdx>{},
|
||||
a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// Load B warp tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() =
|
||||
b_warp_tile_.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kInnerIdx>{},
|
||||
b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// Synchronization barrier at the end of last iteration
|
||||
if constexpr(kQScale == Traits::QScalesPerBlockRow - 1 &&
|
||||
kIterInQScale == Traits::KIterPerQScale - 1 &&
|
||||
mIter.value == MIterPerWarp - 1 &&
|
||||
nIter.value == NIterPerWarp - 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
block_sync_lds();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// Accumulate: first iteration initializes, rest accumulate
|
||||
if constexpr(kIterInQScale == 0)
|
||||
{
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
|
||||
// Set priority for scheduling
|
||||
if constexpr(kInnerIdx == 0 && mIter.value == 0 && nIter.value == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(1);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
});
|
||||
|
||||
// Apply quantization scale after accumulating all K iterations for this
|
||||
// group
|
||||
constexpr auto tbuf_offset =
|
||||
number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
|
||||
merge_sequences(sequence<mIter, nIter>{},
|
||||
c_warp_y_index_zeros)) /
|
||||
CBlockTensor::PackedSize>{};
|
||||
|
||||
AQPickerCommon<AQBlockTensor, Traits, mIter, kQScale> aq_picker(
|
||||
aq_block_tensor);
|
||||
|
||||
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
|
||||
[&](auto c_row) {
|
||||
float scale_reg_f = aq_picker.template pick<c_row>();
|
||||
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
|
||||
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
@@ -329,7 +539,8 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
MakeCBlockTile();
|
||||
}
|
||||
|
||||
template <typename ASmemBlockWindow,
|
||||
template <index_t KIdx = 0,
|
||||
typename ASmemBlockWindow,
|
||||
typename BSmemBlockWindow,
|
||||
bool ALoadTranspose = false,
|
||||
bool BLoadTranspose = false>
|
||||
@@ -338,7 +549,15 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
bool_constant<ALoadTranspose> a_load_tr = {},
|
||||
bool_constant<BLoadTranspose> b_load_tr = {})
|
||||
{
|
||||
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
if constexpr(Scheduler == GemmPipelineScheduler::Interwave)
|
||||
{
|
||||
block_gemm_impl_.template LocalPrefetch<KIdx>(
|
||||
a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
}
|
||||
else
|
||||
{
|
||||
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
|
||||
}
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
|
||||
@@ -499,7 +499,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
return PipelineImpl<GemmPipelineScheduler::Intrawave>{}
|
||||
.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const OverrideADataType& a) { return a; },
|
||||
[](const BDataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
aq_dram_block_window_tmp,
|
||||
|
||||
@@ -392,8 +392,4 @@ struct BlockReduce2D
|
||||
InDataType reduce_init;
|
||||
};
|
||||
|
||||
// deduction guide
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&) -> BlockReduce2D<T>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -40,7 +40,7 @@ struct BlockSoftmax2D
|
||||
#endif
|
||||
|
||||
// compute row max
|
||||
auto reduce_row_max = BlockReduce2D{x, -numeric<DataType>::infinity()};
|
||||
auto reduce_row_max = BlockReduce2D<decltype(x)>{x, -numeric<DataType>::infinity()};
|
||||
#if _BLOCK_SOFTMAX_USE_UNPACK2
|
||||
auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{});
|
||||
#else
|
||||
|
||||
@@ -10,49 +10,55 @@
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include <array>
|
||||
|
||||
namespace ck {
|
||||
namespace ref {
|
||||
|
||||
// Optimized backward data convolution kernel working with packed (contiguous) tensors
|
||||
// Computes gradients w.r.t. input from output gradients and weights
|
||||
// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter],
|
||||
// output[G][N][K][spatial]
|
||||
// Optimized backward data convolution kernel working with packed (contiguous) tensors with
|
||||
// multi-ABD support Computes gradients w.r.t. input from output gradients and weights Assumes
|
||||
// row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], output[G][N][K][spatial]
|
||||
template <index_t NDimSpatial,
|
||||
index_t NumAExtra, // Number of extra A (output gradient) tensors
|
||||
index_t NumBExtra, // Number of extra B (weight) tensors
|
||||
index_t NumD, // Number of D tensors
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename DDataType, // D tensor data type
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp>
|
||||
__global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
const WeiDataType* __restrict__ p_wei,
|
||||
const OutDataType* __restrict__ p_out,
|
||||
index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Di,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t Z,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Do,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t stride_z,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_z,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t pad_z,
|
||||
index_t pad_y,
|
||||
index_t pad_x,
|
||||
InElementOp in_op,
|
||||
WeiElementOp wei_op,
|
||||
OutElementOp out_op)
|
||||
__global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_in,
|
||||
const WeiDataType* const* __restrict__ p_weis,
|
||||
const OutDataType* const* __restrict__ p_outs,
|
||||
const DDataType* const* __restrict__ p_ds,
|
||||
const index_t* const* __restrict__ p_d_strides,
|
||||
index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Di,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t Z,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Do,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t stride_z,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_z,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t pad_z,
|
||||
index_t pad_y,
|
||||
index_t pad_x,
|
||||
InElementOp in_op,
|
||||
WeiElementOp wei_op,
|
||||
OutElementOp out_op)
|
||||
{
|
||||
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const long_index_t num_threads = blockDim.x * gridDim.x;
|
||||
@@ -84,9 +90,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
float acc = 0.0f;
|
||||
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
|
||||
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group and batch
|
||||
const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n;
|
||||
const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
@@ -96,21 +103,39 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
long_index_t wo = w_tmp / stride_x;
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
const OutDataType* out_gnk = out_gn;
|
||||
const WeiDataType* wei_gkc = wei_g + c * wei_stride_c;
|
||||
// Pointers at current filter position
|
||||
const OutDataType* output_grad_g_n_k = output_grad_g_n;
|
||||
const WeiDataType* weight_g_k_c = weight_g + c * wei_stride_c;
|
||||
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
out_op(out_val, out_gnk[k * out_stride_k + wo]);
|
||||
wei_op(wei_val, wei_gkc[k * wei_stride_k + x]);
|
||||
// Handle output gradient element-wise operation with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
out_val,
|
||||
out_op,
|
||||
output_grad_g_n_k,
|
||||
p_outs + 1,
|
||||
g * out_stride_g + n * out_stride_n,
|
||||
k * out_stride_k + wo);
|
||||
|
||||
// Handle weight element-wise operation with extra B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
wei_val,
|
||||
wei_op,
|
||||
weight_g_k_c,
|
||||
p_weis + 1,
|
||||
g * wei_stride_g + c * wei_stride_c,
|
||||
k * wei_stride_k + x);
|
||||
|
||||
acc += type_convert<float>(out_val) * type_convert<float>(wei_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
InDataType result = type_convert<InDataType>(acc);
|
||||
in_op(in_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(
|
||||
in_val, in_op, acc, p_ds, p_d_strides, g, n, c, wi);
|
||||
|
||||
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + wi] = in_val;
|
||||
}
|
||||
}
|
||||
@@ -142,9 +167,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
float acc = 0.0f;
|
||||
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
|
||||
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group and batch
|
||||
const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n;
|
||||
const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g;
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
@@ -154,8 +180,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
long_index_t ho = h_tmp / stride_y;
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
const OutDataType* out_gnkh = out_gn + ho * out_stride_h;
|
||||
const WeiDataType* wei_gkcy = wei_g + c * wei_stride_c + y * wei_stride_y;
|
||||
// Pointers at current spatial height and filter Y position
|
||||
const OutDataType* output_grad_at_h = output_grad_g_n + ho * out_stride_h;
|
||||
const WeiDataType* weight_at_c_y =
|
||||
weight_g + c * wei_stride_c + y * wei_stride_y;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
@@ -167,8 +195,25 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
{
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
out_op(out_val, out_gnkh[k * out_stride_k + wo]);
|
||||
wei_op(wei_val, wei_gkcy[k * wei_stride_k + x]);
|
||||
// Handle output gradient element-wise operation with extra
|
||||
// A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
out_val,
|
||||
out_op,
|
||||
output_grad_at_h,
|
||||
p_outs + 1,
|
||||
g * out_stride_g + n * out_stride_n + ho * out_stride_h,
|
||||
k * out_stride_k + wo);
|
||||
|
||||
// Handle weight element-wise operation with extra B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
wei_val,
|
||||
wei_op,
|
||||
weight_at_c_y,
|
||||
p_weis + 1,
|
||||
g * wei_stride_g + c * wei_stride_c + y * wei_stride_y,
|
||||
k * wei_stride_k + x);
|
||||
|
||||
acc += type_convert<float>(out_val) *
|
||||
type_convert<float>(wei_val);
|
||||
}
|
||||
@@ -179,8 +224,17 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
}
|
||||
}
|
||||
|
||||
InDataType result = type_convert<InDataType>(acc);
|
||||
in_op(in_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(in_val,
|
||||
in_op,
|
||||
acc,
|
||||
p_ds,
|
||||
p_d_strides,
|
||||
g,
|
||||
n,
|
||||
c,
|
||||
hi * p_d_strides[0][3] +
|
||||
wi * p_d_strides[0][4]);
|
||||
|
||||
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + hi * in_stride_h + wi] =
|
||||
in_val;
|
||||
}
|
||||
@@ -218,9 +272,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
float acc = 0.0f;
|
||||
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
|
||||
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group and batch
|
||||
const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n;
|
||||
const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g;
|
||||
|
||||
for(index_t z = 0; z < Z; ++z)
|
||||
{
|
||||
@@ -230,8 +285,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
long_index_t do_idx = d_tmp / stride_z;
|
||||
if(do_idx >= 0 && do_idx < Do)
|
||||
{
|
||||
const OutDataType* out_gnkd = out_gn + do_idx * out_stride_d;
|
||||
const WeiDataType* wei_gkcz = wei_g + c * wei_stride_c + z * wei_stride_z;
|
||||
// Pointers at current spatial depth
|
||||
const OutDataType* output_grad_at_d =
|
||||
output_grad_g_n + do_idx * out_stride_d;
|
||||
const WeiDataType* weight_at_c_z =
|
||||
weight_g + c * wei_stride_c + z * wei_stride_z;
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
@@ -241,8 +299,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
long_index_t ho = h_tmp / stride_y;
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
const OutDataType* out_gnkdh = out_gnkd + ho * out_stride_h;
|
||||
const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y;
|
||||
// Pointers at current spatial depth and height
|
||||
const OutDataType* output_grad_at_d_h =
|
||||
output_grad_at_d + ho * out_stride_h;
|
||||
const WeiDataType* weight_at_c_z_y =
|
||||
weight_at_c_z + y * wei_stride_y;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
@@ -254,10 +315,31 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
{
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
out_op(out_val,
|
||||
out_gnkdh[k * out_stride_k + wo]);
|
||||
wei_op(wei_val,
|
||||
wei_gkczy[k * wei_stride_k + x]);
|
||||
// Handle output gradient element-wise operation
|
||||
// with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<
|
||||
NumAExtra>(out_val,
|
||||
out_op,
|
||||
output_grad_at_d_h,
|
||||
p_outs + 1,
|
||||
g * out_stride_g +
|
||||
n * out_stride_n +
|
||||
do_idx * out_stride_d +
|
||||
ho * out_stride_h,
|
||||
k * out_stride_k + wo);
|
||||
|
||||
// Handle weight element-wise operation with
|
||||
// extra B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<
|
||||
NumBExtra>(
|
||||
wei_val,
|
||||
wei_op,
|
||||
weight_at_c_z_y,
|
||||
p_weis + 1,
|
||||
g * wei_stride_g + c * wei_stride_c +
|
||||
z * wei_stride_z + y * wei_stride_y,
|
||||
k * wei_stride_k + x);
|
||||
|
||||
acc += type_convert<float>(out_val) *
|
||||
type_convert<float>(wei_val);
|
||||
}
|
||||
@@ -271,16 +353,28 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
}
|
||||
}
|
||||
|
||||
InDataType result = type_convert<InDataType>(acc);
|
||||
in_op(in_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(
|
||||
in_val,
|
||||
in_op,
|
||||
acc,
|
||||
p_ds,
|
||||
p_d_strides,
|
||||
g,
|
||||
n,
|
||||
c,
|
||||
di * p_d_strides[0][3] + hi * p_d_strides[0][4] + wi * p_d_strides[0][5]);
|
||||
|
||||
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + di * in_stride_d +
|
||||
hi * in_stride_h + wi] = in_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GPU reference backward data convolution - takes ConvParam directly
|
||||
template <typename InLayout,
|
||||
// GPU reference backward data convolution with multi-ABD support - takes ConvParam directly
|
||||
template <ck::index_t NumAElementwise = 0,
|
||||
ck::index_t NumBElementwise = 0,
|
||||
ck::index_t NumDElementwise = 0,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename TIn,
|
||||
@@ -288,15 +382,20 @@ template <typename InLayout,
|
||||
typename TOut,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
void naive_conv_bwd_data(TIn* p_in,
|
||||
const TWei* p_wei,
|
||||
const TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
typename OutElementwiseOperation,
|
||||
typename TD = TIn> // D tensor type, defaults to TIn for backward compatibility
|
||||
void naive_conv_bwd_data_multi_abd(
|
||||
TIn* p_in,
|
||||
const std::array<const TWei*, NumBElementwise + 1>& p_weis,
|
||||
const std::array<const TOut*, NumAElementwise + 1>& p_outs,
|
||||
const std::array<const TD*, NumDElementwise>& p_ds,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
[[maybe_unused]] const std::array<std::vector<index_t>, NumDElementwise>& d_lengths,
|
||||
const std::array<std::vector<index_t>, NumDElementwise>& d_strides,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
const auto ndim = conv_param.num_dim_spatial_;
|
||||
|
||||
@@ -327,12 +426,34 @@ void naive_conv_bwd_data(TIn* p_in,
|
||||
|
||||
// Allocate packed buffers
|
||||
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
|
||||
SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei));
|
||||
SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut));
|
||||
|
||||
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
|
||||
TWei* p_wei_packed = static_cast<TWei*>(wei_packed_buf.GetDeviceBuffer());
|
||||
TOut* p_out_packed = static_cast<TOut*>(out_packed_buf.GetDeviceBuffer());
|
||||
std::vector<SimpleDeviceMem> wei_packed_bufs;
|
||||
wei_packed_bufs.reserve(NumBElementwise + 1);
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
wei_packed_bufs.emplace_back(wei_total * sizeof(TWei));
|
||||
}
|
||||
|
||||
std::vector<SimpleDeviceMem> out_packed_bufs;
|
||||
out_packed_bufs.reserve(NumAElementwise + 1);
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
out_packed_bufs.emplace_back(out_total * sizeof(TOut));
|
||||
}
|
||||
|
||||
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
|
||||
|
||||
std::array<TWei*, NumBElementwise + 1> p_weis_packed;
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
p_weis_packed[i] = static_cast<TWei*>(wei_packed_bufs[i].GetDeviceBuffer());
|
||||
}
|
||||
|
||||
std::array<TOut*, NumAElementwise + 1> p_outs_packed;
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
p_outs_packed[i] = static_cast<TOut*>(out_packed_bufs[i].GetDeviceBuffer());
|
||||
}
|
||||
|
||||
// Compute strides and allocate device arrays for pack/unpack
|
||||
std::vector<index_t> in_strides = compute_conv_tensor_strides<InLayout>(in_lengths, ndim);
|
||||
@@ -369,12 +490,76 @@ void naive_conv_bwd_data(TIn* p_in,
|
||||
|
||||
// Pack output and weight tensors to contiguous layout (inputs to bwd data)
|
||||
constexpr int block_size = 256;
|
||||
strided_copy_kernel<TOut, false>
|
||||
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_out, p_out_packed, d_out_lengths, d_out_strides, dim_count, out_total);
|
||||
strided_copy_kernel<TWei, false>
|
||||
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total);
|
||||
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
strided_copy_kernel<TOut, false>
|
||||
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_outs[i], p_outs_packed[i], d_out_lengths, d_out_strides, dim_count, out_total);
|
||||
}
|
||||
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
strided_copy_kernel<TWei, false>
|
||||
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total);
|
||||
}
|
||||
|
||||
// Prepare D tensor stride arrays on device
|
||||
std::vector<SimpleDeviceMem> d_stride_bufs;
|
||||
std::array<index_t*, NumDElementwise> p_d_strides_dev = {};
|
||||
|
||||
if constexpr(NumDElementwise > 0)
|
||||
{
|
||||
d_stride_bufs.reserve(NumDElementwise);
|
||||
|
||||
for(index_t i = 0; i < NumDElementwise; ++i)
|
||||
{
|
||||
d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t));
|
||||
p_d_strides_dev[i] = static_cast<index_t*>(d_stride_bufs[i].GetDeviceBuffer());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i],
|
||||
d_strides[i].data(),
|
||||
d_strides[i].size() * sizeof(index_t),
|
||||
hipMemcpyHostToDevice));
|
||||
}
|
||||
}
|
||||
|
||||
// Create device arrays of pointers
|
||||
SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*));
|
||||
SimpleDeviceMem outs_ptrs_buf((NumAElementwise + 1) * sizeof(TOut*));
|
||||
SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*));
|
||||
SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*));
|
||||
|
||||
TWei** d_weis_ptrs = static_cast<TWei**>(weis_ptrs_buf.GetDeviceBuffer());
|
||||
TOut** d_outs_ptrs = static_cast<TOut**>(outs_ptrs_buf.GetDeviceBuffer());
|
||||
TD** d_ds_ptrs = static_cast<TD**>(ds_ptrs_buf.GetDeviceBuffer());
|
||||
index_t** d_d_strides_ptrs = static_cast<index_t**>(d_strides_ptrs_buf.GetDeviceBuffer());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs,
|
||||
p_weis_packed.data(),
|
||||
(NumBElementwise + 1) * sizeof(TWei*),
|
||||
hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_outs_ptrs,
|
||||
p_outs_packed.data(),
|
||||
(NumAElementwise + 1) * sizeof(TOut*),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
if constexpr(NumDElementwise > 0)
|
||||
{
|
||||
std::array<const TD*, NumDElementwise> p_ds_dev;
|
||||
for(index_t i = 0; i < NumDElementwise; ++i)
|
||||
{
|
||||
p_ds_dev[i] = p_ds[i];
|
||||
}
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs,
|
||||
p_d_strides_dev.data(),
|
||||
NumDElementwise * sizeof(index_t*),
|
||||
hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
// Build conv parameter vectors for kernel invocation
|
||||
std::vector<index_t> conv_strides(ndim);
|
||||
@@ -392,16 +577,22 @@ void naive_conv_bwd_data(TIn* p_in,
|
||||
|
||||
if(ndim == 1)
|
||||
{
|
||||
naive_conv_bwd_data_packed<1,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
naive_conv_bwd_data_packed_multi_abd<1,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
p_out_packed,
|
||||
d_weis_ptrs,
|
||||
d_outs_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
@@ -430,16 +621,22 @@ void naive_conv_bwd_data(TIn* p_in,
|
||||
}
|
||||
else if(ndim == 2)
|
||||
{
|
||||
naive_conv_bwd_data_packed<2,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
naive_conv_bwd_data_packed_multi_abd<2,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
p_out_packed,
|
||||
d_weis_ptrs,
|
||||
d_outs_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
@@ -468,16 +665,22 @@ void naive_conv_bwd_data(TIn* p_in,
|
||||
}
|
||||
else // 3D
|
||||
{
|
||||
naive_conv_bwd_data_packed<3,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
naive_conv_bwd_data_packed_multi_abd<3,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
p_out_packed,
|
||||
d_weis_ptrs,
|
||||
d_outs_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
@@ -514,5 +717,43 @@ void naive_conv_bwd_data(TIn* p_in,
|
||||
// Memory automatically freed by SimpleDeviceMem destructors
|
||||
}
|
||||
|
||||
// Original naive_conv_bwd_data - now a zero-overhead wrapper
|
||||
template <typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
inline void naive_conv_bwd_data(TIn* p_in,
|
||||
const TWei* p_wei,
|
||||
const TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
std::array<const TWei*, 1> p_weis = {p_wei};
|
||||
std::array<const TOut*, 1> p_outs = {p_out};
|
||||
std::array<const TIn*, 0> p_ds = {};
|
||||
std::array<std::vector<index_t>, 0> d_lengths = {};
|
||||
std::array<std::vector<index_t>, 0> d_strides = {};
|
||||
|
||||
naive_conv_bwd_data_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_in,
|
||||
p_weis,
|
||||
p_outs,
|
||||
p_ds,
|
||||
conv_param,
|
||||
d_lengths,
|
||||
d_strides,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
stream);
|
||||
}
|
||||
|
||||
} // namespace ref
|
||||
} // namespace ck
|
||||
|
||||
@@ -10,49 +10,58 @@
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include <array>
|
||||
|
||||
namespace ck {
|
||||
namespace ref {
|
||||
|
||||
// Optimized backward weight convolution kernel working with packed (contiguous) tensors
|
||||
// Optimized backward weight convolution kernel working with packed (contiguous) tensors with
|
||||
// multi-ABD support
|
||||
// Assumes row-major packing: input[G][N][C][spatial], output_grad[G][N][K][spatial],
|
||||
// weight_grad[G][K][C][filter]
|
||||
// Computes gradient with respect to weights
|
||||
template <index_t NDimSpatial,
|
||||
index_t NumAExtra, // Number of extra A (input) tensors
|
||||
index_t NumBExtra, // Number of extra B (output gradient) tensors
|
||||
index_t NumD, // Number of D tensors
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename DDataType, // D tensor data type
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp>
|
||||
__global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in,
|
||||
WeiDataType* __restrict__ p_wei_grad,
|
||||
const OutDataType* __restrict__ p_out_grad,
|
||||
index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Di,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t Z,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Do,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t stride_z,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_z,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t pad_z,
|
||||
index_t pad_y,
|
||||
index_t pad_x,
|
||||
InElementOp in_op,
|
||||
WeiElementOp wei_op,
|
||||
OutElementOp out_op)
|
||||
__global__ void
|
||||
naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_ins,
|
||||
WeiDataType* __restrict__ p_wei_grad,
|
||||
const OutDataType* const* __restrict__ p_out_grads,
|
||||
const DDataType* const* __restrict__ p_ds,
|
||||
const index_t* const* __restrict__ p_d_strides,
|
||||
index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Di,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t Z,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Do,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t stride_z,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_z,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t pad_z,
|
||||
index_t pad_y,
|
||||
index_t pad_x,
|
||||
InElementOp in_op,
|
||||
WeiElementOp wei_op,
|
||||
OutElementOp out_op)
|
||||
{
|
||||
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const long_index_t num_threads = blockDim.x * gridDim.x;
|
||||
@@ -84,30 +93,50 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
|
||||
const index_t k = remaining % K;
|
||||
const index_t g = remaining / K;
|
||||
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_g = p_in + g * in_stride_g;
|
||||
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group
|
||||
const InDataType* input_g = p_ins[0] + g * in_stride_g;
|
||||
const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g;
|
||||
|
||||
// Loop over batch and output positions
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
const InDataType* in_gn = in_g + n * in_stride_n + c * in_stride_c;
|
||||
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
|
||||
// Pointers at current batch and input channel
|
||||
const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c;
|
||||
const OutDataType* output_grad_at_n_k =
|
||||
output_grad_g + n * out_stride_n + k * out_stride_k;
|
||||
|
||||
for(index_t wo = 0; wo < Wo; ++wo)
|
||||
{
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gn[wi]);
|
||||
out_op(out_val, out_gn_k[wo]);
|
||||
// Handle input element-wise operation with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
in_val,
|
||||
in_op,
|
||||
input_at_n_c,
|
||||
p_ins + 1,
|
||||
g * in_stride_g + n * in_stride_n + c * in_stride_c,
|
||||
wi);
|
||||
|
||||
// Handle output gradient element-wise operation with extra B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
out_val,
|
||||
out_op,
|
||||
output_grad_at_n_k,
|
||||
p_out_grads + 1,
|
||||
g * out_stride_g + n * out_stride_n + k * out_stride_k,
|
||||
wo);
|
||||
|
||||
acc += type_convert<float>(out_val) * type_convert<float>(in_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
WeiDataType result = type_convert<WeiDataType>(acc);
|
||||
wei_op(wei_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(
|
||||
wei_val, wei_op, acc, p_ds, p_d_strides, g, k, c, x);
|
||||
|
||||
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + x] = wei_val;
|
||||
}
|
||||
}
|
||||
@@ -139,31 +168,55 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
|
||||
const index_t k = remaining % K;
|
||||
const index_t g = remaining / K;
|
||||
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_g = p_in + g * in_stride_g;
|
||||
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group
|
||||
const InDataType* input_g = p_ins[0] + g * in_stride_g;
|
||||
const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g;
|
||||
|
||||
// Loop over batch and output positions
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c;
|
||||
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
|
||||
// Pointers at current batch and input channel
|
||||
const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c;
|
||||
const OutDataType* output_grad_at_n_k =
|
||||
output_grad_g + n * out_stride_n + k * out_stride_k;
|
||||
|
||||
for(index_t ho = 0; ho < Ho; ++ho)
|
||||
{
|
||||
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
|
||||
if(hi >= 0 && hi < Hi)
|
||||
{
|
||||
const InDataType* in_gnch = in_gnc + hi * in_stride_h;
|
||||
const OutDataType* out_gn_kh = out_gn_k + ho * out_stride_h;
|
||||
// Pointers at current spatial height
|
||||
const InDataType* input_at_h = input_at_n_c + hi * in_stride_h;
|
||||
const OutDataType* output_grad_at_h =
|
||||
output_grad_at_n_k + ho * out_stride_h;
|
||||
|
||||
for(index_t wo = 0; wo < Wo; ++wo)
|
||||
{
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gnch[wi]);
|
||||
out_op(out_val, out_gn_kh[wo]);
|
||||
// Handle input element-wise operation with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
in_val,
|
||||
in_op,
|
||||
input_at_h,
|
||||
p_ins + 1,
|
||||
g * in_stride_g + n * in_stride_n + c * in_stride_c +
|
||||
hi * in_stride_h,
|
||||
wi);
|
||||
|
||||
// Handle output gradient element-wise operation with extra B
|
||||
// tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
out_val,
|
||||
out_op,
|
||||
output_grad_at_h,
|
||||
p_out_grads + 1,
|
||||
g * out_stride_g + n * out_stride_n + k * out_stride_k +
|
||||
ho * out_stride_h,
|
||||
wo);
|
||||
|
||||
acc += type_convert<float>(out_val) * type_convert<float>(in_val);
|
||||
}
|
||||
}
|
||||
@@ -171,8 +224,17 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
|
||||
}
|
||||
}
|
||||
|
||||
WeiDataType result = type_convert<WeiDataType>(acc);
|
||||
wei_op(wei_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(wei_val,
|
||||
wei_op,
|
||||
acc,
|
||||
p_ds,
|
||||
p_d_strides,
|
||||
g,
|
||||
k,
|
||||
c,
|
||||
y * p_d_strides[0][3] +
|
||||
x * p_d_strides[0][4]);
|
||||
|
||||
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + y * wei_stride_y +
|
||||
x] = wei_val;
|
||||
}
|
||||
@@ -210,39 +272,65 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
|
||||
const index_t k = remaining % K;
|
||||
const index_t g = remaining / K;
|
||||
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_g = p_in + g * in_stride_g;
|
||||
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group
|
||||
const InDataType* input_g = p_ins[0] + g * in_stride_g;
|
||||
const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g;
|
||||
|
||||
// Loop over batch and output positions
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c;
|
||||
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
|
||||
// Pointers at current batch and input channel
|
||||
const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c;
|
||||
const OutDataType* output_grad_at_n_k =
|
||||
output_grad_g + n * out_stride_n + k * out_stride_k;
|
||||
|
||||
for(index_t do_idx = 0; do_idx < Do; ++do_idx)
|
||||
{
|
||||
long_index_t di = do_idx * stride_z + z * dilation_z - pad_z;
|
||||
if(di >= 0 && di < Di)
|
||||
{
|
||||
const InDataType* in_gncd = in_gnc + di * in_stride_d;
|
||||
const OutDataType* out_gn_kd = out_gn_k + do_idx * out_stride_d;
|
||||
// Pointers at current spatial depth
|
||||
const InDataType* input_at_d = input_at_n_c + di * in_stride_d;
|
||||
const OutDataType* output_grad_at_d =
|
||||
output_grad_at_n_k + do_idx * out_stride_d;
|
||||
|
||||
for(index_t ho = 0; ho < Ho; ++ho)
|
||||
{
|
||||
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
|
||||
if(hi >= 0 && hi < Hi)
|
||||
{
|
||||
const InDataType* in_gncdh = in_gncd + hi * in_stride_h;
|
||||
const OutDataType* out_gn_kdh = out_gn_kd + ho * out_stride_h;
|
||||
// Pointers at current spatial depth and height
|
||||
const InDataType* input_at_d_h = input_at_d + hi * in_stride_h;
|
||||
const OutDataType* output_grad_at_d_h =
|
||||
output_grad_at_d + ho * out_stride_h;
|
||||
|
||||
for(index_t wo = 0; wo < Wo; ++wo)
|
||||
{
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gncdh[wi]);
|
||||
out_op(out_val, out_gn_kdh[wo]);
|
||||
// Handle input element-wise operation with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
in_val,
|
||||
in_op,
|
||||
input_at_d_h,
|
||||
p_ins + 1,
|
||||
g * in_stride_g + n * in_stride_n + c * in_stride_c +
|
||||
di * in_stride_d + hi * in_stride_h,
|
||||
wi);
|
||||
|
||||
// Handle output gradient element-wise operation with extra
|
||||
// B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
out_val,
|
||||
out_op,
|
||||
output_grad_at_d_h,
|
||||
p_out_grads + 1,
|
||||
g * out_stride_g + n * out_stride_n + k * out_stride_k +
|
||||
do_idx * out_stride_d + ho * out_stride_h,
|
||||
wo);
|
||||
|
||||
acc += type_convert<float>(out_val) *
|
||||
type_convert<float>(in_val);
|
||||
}
|
||||
@@ -253,16 +341,28 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
|
||||
}
|
||||
}
|
||||
|
||||
WeiDataType result = type_convert<WeiDataType>(acc);
|
||||
wei_op(wei_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(
|
||||
wei_val,
|
||||
wei_op,
|
||||
acc,
|
||||
p_ds,
|
||||
p_d_strides,
|
||||
g,
|
||||
k,
|
||||
c,
|
||||
z * p_d_strides[0][3] + y * p_d_strides[0][4] + x * p_d_strides[0][5]);
|
||||
|
||||
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + z * wei_stride_z +
|
||||
y * wei_stride_y + x] = wei_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GPU reference backward weight convolution - takes ConvParam directly
|
||||
template <typename InLayout,
|
||||
// GPU reference backward weight convolution with multi-ABD support - takes ConvParam directly
|
||||
template <ck::index_t NumAElementwise = 0,
|
||||
ck::index_t NumBElementwise = 0,
|
||||
ck::index_t NumDElementwise = 0,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename TIn,
|
||||
@@ -270,15 +370,20 @@ template <typename InLayout,
|
||||
typename TOut,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
void naive_conv_bwd_weight(const TIn* p_in,
|
||||
TWei* p_wei_grad,
|
||||
const TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
typename OutElementwiseOperation,
|
||||
typename TD = TWei> // D tensor type, defaults to TWei for backward compatibility
|
||||
void naive_conv_bwd_weight_multi_abd(
|
||||
const std::array<const TIn*, NumAElementwise + 1>& p_ins,
|
||||
TWei* p_wei_grad,
|
||||
const std::array<const TOut*, NumBElementwise + 1>& p_outs,
|
||||
const std::array<const TD*, NumDElementwise>& p_ds,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
[[maybe_unused]] const std::array<std::vector<index_t>, NumDElementwise>& d_lengths,
|
||||
const std::array<std::vector<index_t>, NumDElementwise>& d_strides,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
const auto ndim = conv_param.num_dim_spatial_;
|
||||
|
||||
@@ -308,13 +413,35 @@ void naive_conv_bwd_weight(const TIn* p_in,
|
||||
out_total *= l;
|
||||
|
||||
// Allocate packed buffers
|
||||
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
|
||||
SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei));
|
||||
SimpleDeviceMem out_grad_packed_buf(out_total * sizeof(TOut));
|
||||
std::vector<SimpleDeviceMem> in_packed_bufs;
|
||||
in_packed_bufs.reserve(NumAElementwise + 1);
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
in_packed_bufs.emplace_back(in_total * sizeof(TIn));
|
||||
}
|
||||
|
||||
SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei));
|
||||
|
||||
std::vector<SimpleDeviceMem> out_grad_packed_bufs;
|
||||
out_grad_packed_bufs.reserve(NumBElementwise + 1);
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
out_grad_packed_bufs.emplace_back(out_total * sizeof(TOut));
|
||||
}
|
||||
|
||||
std::array<TIn*, NumAElementwise + 1> p_ins_packed;
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
p_ins_packed[i] = static_cast<TIn*>(in_packed_bufs[i].GetDeviceBuffer());
|
||||
}
|
||||
|
||||
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
|
||||
TWei* p_wei_grad_packed = static_cast<TWei*>(wei_grad_packed_buf.GetDeviceBuffer());
|
||||
TOut* p_out_grad_packed = static_cast<TOut*>(out_grad_packed_buf.GetDeviceBuffer());
|
||||
|
||||
std::array<TOut*, NumBElementwise + 1> p_out_grads_packed;
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
p_out_grads_packed[i] = static_cast<TOut*>(out_grad_packed_bufs[i].GetDeviceBuffer());
|
||||
}
|
||||
|
||||
// Compute strides and allocate device arrays for pack/unpack
|
||||
std::vector<index_t> in_strides = compute_conv_tensor_strides<InLayout>(in_lengths, ndim);
|
||||
@@ -351,12 +478,81 @@ void naive_conv_bwd_weight(const TIn* p_in,
|
||||
|
||||
// Pack input and output_grad tensors to contiguous layout (inputs to bwd weight)
|
||||
constexpr int block_size = 256;
|
||||
strided_copy_kernel<TIn, false>
|
||||
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total);
|
||||
strided_copy_kernel<TOut, false>
|
||||
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_out, p_out_grad_packed, d_out_lengths, d_out_strides, dim_count, out_total);
|
||||
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
strided_copy_kernel<TIn, false>
|
||||
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total);
|
||||
}
|
||||
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
strided_copy_kernel<TOut, false>
|
||||
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_outs[i],
|
||||
p_out_grads_packed[i],
|
||||
d_out_lengths,
|
||||
d_out_strides,
|
||||
dim_count,
|
||||
out_total);
|
||||
}
|
||||
|
||||
// Prepare D tensor stride arrays on device
|
||||
std::vector<SimpleDeviceMem> d_stride_bufs;
|
||||
std::array<index_t*, NumDElementwise> p_d_strides_dev = {};
|
||||
|
||||
if constexpr(NumDElementwise > 0)
|
||||
{
|
||||
d_stride_bufs.reserve(NumDElementwise);
|
||||
|
||||
for(index_t i = 0; i < NumDElementwise; ++i)
|
||||
{
|
||||
d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t));
|
||||
p_d_strides_dev[i] = static_cast<index_t*>(d_stride_bufs[i].GetDeviceBuffer());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i],
|
||||
d_strides[i].data(),
|
||||
d_strides[i].size() * sizeof(index_t),
|
||||
hipMemcpyHostToDevice));
|
||||
}
|
||||
}
|
||||
|
||||
// Create device arrays of pointers
|
||||
SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*));
|
||||
SimpleDeviceMem out_grads_ptrs_buf((NumBElementwise + 1) * sizeof(TOut*));
|
||||
SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*));
|
||||
SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*));
|
||||
|
||||
TIn** d_ins_ptrs = static_cast<TIn**>(ins_ptrs_buf.GetDeviceBuffer());
|
||||
TOut** d_out_grads_ptrs = static_cast<TOut**>(out_grads_ptrs_buf.GetDeviceBuffer());
|
||||
TD** d_ds_ptrs = static_cast<TD**>(ds_ptrs_buf.GetDeviceBuffer());
|
||||
index_t** d_d_strides_ptrs = static_cast<index_t**>(d_strides_ptrs_buf.GetDeviceBuffer());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs,
|
||||
p_ins_packed.data(),
|
||||
(NumAElementwise + 1) * sizeof(TIn*),
|
||||
hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_out_grads_ptrs,
|
||||
p_out_grads_packed.data(),
|
||||
(NumBElementwise + 1) * sizeof(TOut*),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
if constexpr(NumDElementwise > 0)
|
||||
{
|
||||
std::array<const TD*, NumDElementwise> p_ds_dev;
|
||||
for(index_t i = 0; i < NumDElementwise; ++i)
|
||||
{
|
||||
p_ds_dev[i] = p_ds[i];
|
||||
}
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs,
|
||||
p_d_strides_dev.data(),
|
||||
NumDElementwise * sizeof(index_t*),
|
||||
hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
// Build conv parameter vectors for kernel invocation
|
||||
std::vector<index_t> conv_strides(ndim);
|
||||
@@ -374,16 +570,22 @@ void naive_conv_bwd_weight(const TIn* p_in,
|
||||
|
||||
if(ndim == 1)
|
||||
{
|
||||
naive_conv_bwd_weight_packed<1,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
naive_conv_bwd_weight_packed_multi_abd<1,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<wei_grid, block_size, 0, stream>>>(d_ins_ptrs,
|
||||
p_wei_grad_packed,
|
||||
p_out_grad_packed,
|
||||
d_out_grads_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
@@ -412,16 +614,22 @@ void naive_conv_bwd_weight(const TIn* p_in,
|
||||
}
|
||||
else if(ndim == 2)
|
||||
{
|
||||
naive_conv_bwd_weight_packed<2,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
naive_conv_bwd_weight_packed_multi_abd<2,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<wei_grid, block_size, 0, stream>>>(d_ins_ptrs,
|
||||
p_wei_grad_packed,
|
||||
p_out_grad_packed,
|
||||
d_out_grads_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
@@ -450,16 +658,22 @@ void naive_conv_bwd_weight(const TIn* p_in,
|
||||
}
|
||||
else // 3D
|
||||
{
|
||||
naive_conv_bwd_weight_packed<3,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
naive_conv_bwd_weight_packed_multi_abd<3,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<wei_grid, block_size, 0, stream>>>(d_ins_ptrs,
|
||||
p_wei_grad_packed,
|
||||
p_out_grad_packed,
|
||||
d_out_grads_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
@@ -496,5 +710,44 @@ void naive_conv_bwd_weight(const TIn* p_in,
|
||||
// Memory automatically freed by SimpleDeviceMem destructors
|
||||
}
|
||||
|
||||
// Original naive_conv_bwd_weight - now a zero-overhead wrapper
|
||||
template <typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
inline void
|
||||
naive_conv_bwd_weight(const TIn* p_in,
|
||||
TWei* p_wei_grad,
|
||||
const TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
std::array<const TIn*, 1> p_ins = {p_in};
|
||||
std::array<const TOut*, 1> p_outs = {p_out};
|
||||
std::array<const TWei*, 0> p_ds = {};
|
||||
std::array<std::vector<index_t>, 0> d_lengths = {};
|
||||
std::array<std::vector<index_t>, 0> d_strides = {};
|
||||
|
||||
naive_conv_bwd_weight_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_ins,
|
||||
p_wei_grad,
|
||||
p_outs,
|
||||
p_ds,
|
||||
conv_param,
|
||||
d_lengths,
|
||||
d_strides,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
stream);
|
||||
}
|
||||
|
||||
} // namespace ref
|
||||
} // namespace ck
|
||||
|
||||
@@ -10,48 +10,56 @@
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include <array>
|
||||
|
||||
namespace ck {
|
||||
namespace ref {
|
||||
|
||||
// Optimized convolution kernel working with packed (contiguous) tensors
|
||||
// Optimized convolution kernel working with packed (contiguous) tensors with multi-ABD support
|
||||
// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter],
|
||||
// output[G][N][K][spatial]
|
||||
template <index_t NDimSpatial,
|
||||
index_t NumAExtra, // Number of extra A (input) tensors
|
||||
index_t NumBExtra, // Number of extra B (weight) tensors
|
||||
index_t NumD, // Number of D tensors
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename DDataType, // D tensor data type
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp>
|
||||
__global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
|
||||
const WeiDataType* __restrict__ p_wei,
|
||||
OutDataType* __restrict__ p_out,
|
||||
index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Di,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t Z,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Do,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t stride_z,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_z,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t pad_z,
|
||||
index_t pad_y,
|
||||
index_t pad_x,
|
||||
InElementOp in_op,
|
||||
WeiElementOp wei_op,
|
||||
OutElementOp out_op)
|
||||
__global__ void naive_conv_fwd_packed_multi_abd(
|
||||
const InDataType* const* __restrict__ p_ins, // Array of input pointers (1 + NumAExtra)
|
||||
const WeiDataType* const* __restrict__ p_weis, // Array of weight pointers (1 + NumBExtra)
|
||||
const DDataType* const* __restrict__ p_ds, // Array of D tensor pointers
|
||||
const index_t* const* __restrict__ p_d_strides, // Array of D tensor stride arrays
|
||||
OutDataType* __restrict__ p_out,
|
||||
index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Di,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t Z,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Do,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t stride_z,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_z,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t pad_z,
|
||||
index_t pad_y,
|
||||
index_t pad_x,
|
||||
InElementOp in_op,
|
||||
WeiElementOp wei_op,
|
||||
OutElementOp out_op)
|
||||
{
|
||||
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const long_index_t num_threads = blockDim.x * gridDim.x;
|
||||
@@ -83,29 +91,48 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_g = p_in + g * in_stride_g + n * in_stride_n;
|
||||
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group, batch, and output channel
|
||||
const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n;
|
||||
const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k;
|
||||
|
||||
for(index_t c = 0; c < C; ++c)
|
||||
{
|
||||
const InDataType* in_gc = in_g + c * in_stride_c;
|
||||
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
|
||||
// Pointers at current input channel
|
||||
const InDataType* input_at_c = input_g_n + c * in_stride_c;
|
||||
const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gc[wi]);
|
||||
wei_op(wei_val, wei_gkc[x]);
|
||||
// Handle input element-wise operation with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
in_val,
|
||||
in_op,
|
||||
input_at_c,
|
||||
p_ins + 1,
|
||||
g * in_stride_g + n * in_stride_n + c * in_stride_c,
|
||||
wi);
|
||||
|
||||
// Handle weight element-wise operation with extra B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
wei_val,
|
||||
wei_op,
|
||||
weight_at_c,
|
||||
p_weis + 1,
|
||||
g * wei_stride_g + k * wei_stride_k + c * wei_stride_c,
|
||||
x);
|
||||
|
||||
acc += type_convert<float>(in_val) * type_convert<float>(wei_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OutDataType result = type_convert<OutDataType>(acc);
|
||||
out_op(out_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(
|
||||
out_val, out_op, acc, p_ds, p_d_strides, g, n, k, wo);
|
||||
|
||||
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + wo] = out_val;
|
||||
}
|
||||
}
|
||||
@@ -137,30 +164,51 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n;
|
||||
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group, batch, and output channel
|
||||
const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n;
|
||||
const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k;
|
||||
|
||||
for(index_t c = 0; c < C; ++c)
|
||||
{
|
||||
const InDataType* in_gnc = in_gn + c * in_stride_c;
|
||||
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
|
||||
// Pointers at current input channel
|
||||
const InDataType* input_at_c = input_g_n + c * in_stride_c;
|
||||
const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c;
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
|
||||
if(hi >= 0 && hi < Hi)
|
||||
{
|
||||
const InDataType* in_gnch = in_gnc + hi * in_stride_h;
|
||||
const WeiDataType* wei_gkcy = wei_gkc + y * wei_stride_y;
|
||||
// Pointers at current spatial height and filter Y position
|
||||
const InDataType* input_at_h = input_at_c + hi * in_stride_h;
|
||||
const WeiDataType* weight_at_y = weight_at_c + y * wei_stride_y;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gnch[wi]);
|
||||
wei_op(wei_val, wei_gkcy[x]);
|
||||
// Handle input element-wise operation with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
in_val,
|
||||
in_op,
|
||||
input_at_h,
|
||||
p_ins + 1,
|
||||
g * in_stride_g + n * in_stride_n + c * in_stride_c +
|
||||
hi * in_stride_h,
|
||||
wi);
|
||||
|
||||
// Handle weight element-wise operation with extra B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
wei_val,
|
||||
wei_op,
|
||||
weight_at_y,
|
||||
p_weis + 1,
|
||||
g * wei_stride_g + k * wei_stride_k + c * wei_stride_c +
|
||||
y * wei_stride_y,
|
||||
x);
|
||||
|
||||
acc += type_convert<float>(in_val) * type_convert<float>(wei_val);
|
||||
}
|
||||
}
|
||||
@@ -168,8 +216,17 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
|
||||
}
|
||||
}
|
||||
|
||||
OutDataType result = type_convert<OutDataType>(acc);
|
||||
out_op(out_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(out_val,
|
||||
out_op,
|
||||
acc,
|
||||
p_ds,
|
||||
p_d_strides,
|
||||
g,
|
||||
n,
|
||||
k,
|
||||
ho * p_d_strides[0][3] +
|
||||
wo * p_d_strides[0][4]);
|
||||
|
||||
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + ho * out_stride_h + wo] =
|
||||
out_val;
|
||||
}
|
||||
@@ -207,38 +264,60 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n;
|
||||
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group, batch, and output channel
|
||||
const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n;
|
||||
const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k;
|
||||
|
||||
for(index_t c = 0; c < C; ++c)
|
||||
{
|
||||
const InDataType* in_gnc = in_gn + c * in_stride_c;
|
||||
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
|
||||
// Pointers at current input channel
|
||||
const InDataType* input_at_c = input_g_n + c * in_stride_c;
|
||||
const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c;
|
||||
|
||||
for(index_t z = 0; z < Z; ++z)
|
||||
{
|
||||
long_index_t di = do_idx * stride_z + z * dilation_z - pad_z;
|
||||
if(di >= 0 && di < Di)
|
||||
{
|
||||
const InDataType* in_gncd = in_gnc + di * in_stride_d;
|
||||
const WeiDataType* wei_gkcz = wei_gkc + z * wei_stride_z;
|
||||
// Pointers at current spatial depth
|
||||
const InDataType* input_at_d = input_at_c + di * in_stride_d;
|
||||
const WeiDataType* weight_at_z = weight_at_c + z * wei_stride_z;
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
|
||||
if(hi >= 0 && hi < Hi)
|
||||
{
|
||||
const InDataType* in_gncdh = in_gncd + hi * in_stride_h;
|
||||
const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y;
|
||||
// Pointers at current spatial depth and height
|
||||
const InDataType* input_at_d_h = input_at_d + hi * in_stride_h;
|
||||
const WeiDataType* weight_at_z_y = weight_at_z + y * wei_stride_y;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gncdh[wi]);
|
||||
wei_op(wei_val, wei_gkczy[x]);
|
||||
// Handle input element-wise operation with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
in_val,
|
||||
in_op,
|
||||
input_at_d_h,
|
||||
p_ins + 1,
|
||||
g * in_stride_g + n * in_stride_n + c * in_stride_c +
|
||||
di * in_stride_d + hi * in_stride_h,
|
||||
wi);
|
||||
|
||||
// Handle weight element-wise operation with extra B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
wei_val,
|
||||
wei_op,
|
||||
weight_at_z_y,
|
||||
p_weis + 1,
|
||||
g * wei_stride_g + k * wei_stride_k + c * wei_stride_c +
|
||||
z * wei_stride_z + y * wei_stride_y,
|
||||
x);
|
||||
|
||||
acc += type_convert<float>(in_val) *
|
||||
type_convert<float>(wei_val);
|
||||
}
|
||||
@@ -249,16 +328,28 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
|
||||
}
|
||||
}
|
||||
|
||||
OutDataType result = type_convert<OutDataType>(acc);
|
||||
out_op(out_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(
|
||||
out_val,
|
||||
out_op,
|
||||
acc,
|
||||
p_ds,
|
||||
p_d_strides,
|
||||
g,
|
||||
n,
|
||||
k,
|
||||
do_idx * p_d_strides[0][3] + ho * p_d_strides[0][4] + wo * p_d_strides[0][5]);
|
||||
|
||||
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + do_idx * out_stride_d +
|
||||
ho * out_stride_h + wo] = out_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GPU reference convolution - takes ConvParam directly
|
||||
template <typename InLayout,
|
||||
// GPU reference convolution with multi-ABD support - takes ConvParam directly
|
||||
template <ck::index_t NumAElementwise = 0,
|
||||
ck::index_t NumBElementwise = 0,
|
||||
ck::index_t NumDElementwise = 0,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename TIn,
|
||||
@@ -266,15 +357,20 @@ template <typename InLayout,
|
||||
typename TOut,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
void naive_conv_fwd(const TIn* p_in,
|
||||
const TWei* p_wei,
|
||||
TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
typename OutElementwiseOperation,
|
||||
typename TD = TOut> // D tensor type, defaults to TOut for backward compatibility
|
||||
void naive_conv_fwd_multi_abd(
|
||||
const std::array<const TIn*, NumAElementwise + 1>& p_ins,
|
||||
const std::array<const TWei*, NumBElementwise + 1>& p_weis,
|
||||
const std::array<const TD*, NumDElementwise>& p_ds,
|
||||
TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
[[maybe_unused]] const std::array<std::vector<index_t>, NumDElementwise>& d_lengths,
|
||||
const std::array<std::vector<index_t>, NumDElementwise>& d_strides,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
const auto ndim = conv_param.num_dim_spatial_;
|
||||
|
||||
@@ -303,13 +399,37 @@ void naive_conv_fwd(const TIn* p_in,
|
||||
for(auto l : out_lengths)
|
||||
out_total *= l;
|
||||
|
||||
// Allocate packed buffers
|
||||
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
|
||||
SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei));
|
||||
// Allocate packed buffers for all A and B tensors
|
||||
// Use separate allocations to avoid copy assignment issues with RAII wrapper
|
||||
std::vector<SimpleDeviceMem> in_packed_bufs;
|
||||
in_packed_bufs.reserve(NumAElementwise + 1);
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
in_packed_bufs.emplace_back(in_total * sizeof(TIn));
|
||||
}
|
||||
|
||||
std::vector<SimpleDeviceMem> wei_packed_bufs;
|
||||
wei_packed_bufs.reserve(NumBElementwise + 1);
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
wei_packed_bufs.emplace_back(wei_total * sizeof(TWei));
|
||||
}
|
||||
|
||||
SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut));
|
||||
|
||||
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
|
||||
TWei* p_wei_packed = static_cast<TWei*>(wei_packed_buf.GetDeviceBuffer());
|
||||
// Get packed buffer pointers
|
||||
std::array<TIn*, NumAElementwise + 1> p_ins_packed;
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
p_ins_packed[i] = static_cast<TIn*>(in_packed_bufs[i].GetDeviceBuffer());
|
||||
}
|
||||
|
||||
std::array<TWei*, NumBElementwise + 1> p_weis_packed;
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
p_weis_packed[i] = static_cast<TWei*>(wei_packed_bufs[i].GetDeviceBuffer());
|
||||
}
|
||||
|
||||
TOut* p_out_packed = static_cast<TOut*>(out_packed_buf.GetDeviceBuffer());
|
||||
|
||||
// Compute strides and allocate device arrays for pack/unpack
|
||||
@@ -347,12 +467,82 @@ void naive_conv_fwd(const TIn* p_in,
|
||||
|
||||
// Pack input and weight tensors to contiguous layout
|
||||
constexpr int block_size = 256;
|
||||
strided_copy_kernel<TIn, false>
|
||||
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total);
|
||||
strided_copy_kernel<TWei, false>
|
||||
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total);
|
||||
|
||||
// Pack all A tensors
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
strided_copy_kernel<TIn, false>
|
||||
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total);
|
||||
}
|
||||
|
||||
// Pack all B tensors
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
strided_copy_kernel<TWei, false>
|
||||
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total);
|
||||
}
|
||||
|
||||
// Prepare D tensor stride arrays on device
|
||||
// NOTE: D tensors are NOT packed - they are used directly with their original strides
|
||||
// to support broadcasting (e.g., BiasGK layout with zero strides)
|
||||
std::vector<SimpleDeviceMem> d_stride_bufs;
|
||||
std::array<index_t*, NumDElementwise> p_d_strides_dev = {};
|
||||
|
||||
if constexpr(NumDElementwise > 0)
|
||||
{
|
||||
d_stride_bufs.reserve(NumDElementwise);
|
||||
|
||||
for(index_t i = 0; i < NumDElementwise; ++i)
|
||||
{
|
||||
// Allocate and copy strides to device
|
||||
d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t));
|
||||
p_d_strides_dev[i] = static_cast<index_t*>(d_stride_bufs[i].GetDeviceBuffer());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i],
|
||||
d_strides[i].data(),
|
||||
d_strides[i].size() * sizeof(index_t),
|
||||
hipMemcpyHostToDevice));
|
||||
}
|
||||
}
|
||||
|
||||
// Create device arrays of pointers
|
||||
SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*));
|
||||
SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*));
|
||||
SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*));
|
||||
SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*));
|
||||
|
||||
TIn** d_ins_ptrs = static_cast<TIn**>(ins_ptrs_buf.GetDeviceBuffer());
|
||||
TWei** d_weis_ptrs = static_cast<TWei**>(weis_ptrs_buf.GetDeviceBuffer());
|
||||
TD** d_ds_ptrs = static_cast<TD**>(ds_ptrs_buf.GetDeviceBuffer());
|
||||
index_t** d_d_strides_ptrs = static_cast<index_t**>(d_strides_ptrs_buf.GetDeviceBuffer());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs,
|
||||
p_ins_packed.data(),
|
||||
(NumAElementwise + 1) * sizeof(TIn*),
|
||||
hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs,
|
||||
p_weis_packed.data(),
|
||||
(NumBElementwise + 1) * sizeof(TWei*),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
if constexpr(NumDElementwise > 0)
|
||||
{
|
||||
// D tensors use original pointers (not packed) to support broadcasting
|
||||
std::array<const TD*, NumDElementwise> p_ds_dev;
|
||||
for(index_t i = 0; i < NumDElementwise; ++i)
|
||||
{
|
||||
p_ds_dev[i] = p_ds[i];
|
||||
}
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs,
|
||||
p_d_strides_dev.data(),
|
||||
NumDElementwise * sizeof(index_t*),
|
||||
hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
// Build conv parameter vectors for kernel invocation
|
||||
std::vector<index_t> conv_strides(ndim);
|
||||
@@ -370,15 +560,21 @@ void naive_conv_fwd(const TIn* p_in,
|
||||
|
||||
if(ndim == 1)
|
||||
{
|
||||
naive_conv_fwd_packed<1,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
naive_conv_fwd_packed_multi_abd<1,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<out_grid, block_size, 0, stream>>>(d_ins_ptrs,
|
||||
d_weis_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
p_out_packed,
|
||||
G,
|
||||
N,
|
||||
@@ -408,15 +604,21 @@ void naive_conv_fwd(const TIn* p_in,
|
||||
}
|
||||
else if(ndim == 2)
|
||||
{
|
||||
naive_conv_fwd_packed<2,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
naive_conv_fwd_packed_multi_abd<2,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<out_grid, block_size, 0, stream>>>(d_ins_ptrs,
|
||||
d_weis_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
p_out_packed,
|
||||
G,
|
||||
N,
|
||||
@@ -446,15 +648,21 @@ void naive_conv_fwd(const TIn* p_in,
|
||||
}
|
||||
else // 3D
|
||||
{
|
||||
naive_conv_fwd_packed<3,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
naive_conv_fwd_packed_multi_abd<3,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<out_grid, block_size, 0, stream>>>(d_ins_ptrs,
|
||||
d_weis_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
p_out_packed,
|
||||
G,
|
||||
N,
|
||||
@@ -492,5 +700,43 @@ void naive_conv_fwd(const TIn* p_in,
|
||||
// Memory automatically freed by SimpleDeviceMem destructors
|
||||
}
|
||||
|
||||
// Original naive_conv_fwd - now a zero-overhead wrapper
|
||||
template <typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
inline void naive_conv_fwd(const TIn* p_in,
|
||||
const TWei* p_wei,
|
||||
TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
std::array<const TIn*, 1> p_ins = {p_in};
|
||||
std::array<const TWei*, 1> p_weis = {p_wei};
|
||||
std::array<const TOut*, 0> p_ds = {};
|
||||
std::array<std::vector<index_t>, 0> d_lengths = {};
|
||||
std::array<std::vector<index_t>, 0> d_strides = {};
|
||||
|
||||
naive_conv_fwd_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_ins,
|
||||
p_weis,
|
||||
p_ds,
|
||||
p_out,
|
||||
conv_param,
|
||||
d_lengths,
|
||||
d_strides,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
stream);
|
||||
}
|
||||
|
||||
} // namespace ref
|
||||
} // namespace ck
|
||||
|
||||
@@ -22,9 +22,39 @@ struct SimpleDeviceMem
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&p_mem_), mem_size));
|
||||
}
|
||||
|
||||
// Delete copy operations (resource should not be copied)
|
||||
SimpleDeviceMem(const SimpleDeviceMem&) = delete;
|
||||
SimpleDeviceMem& operator=(const SimpleDeviceMem&) = delete;
|
||||
|
||||
// Define move operations
|
||||
SimpleDeviceMem(SimpleDeviceMem&& other) noexcept : p_mem_(other.p_mem_)
|
||||
{
|
||||
other.p_mem_ = nullptr;
|
||||
}
|
||||
|
||||
SimpleDeviceMem& operator=(SimpleDeviceMem&& other) noexcept
|
||||
{
|
||||
if(this != &other)
|
||||
{
|
||||
if(p_mem_)
|
||||
{
|
||||
(void)hipFree(p_mem_);
|
||||
}
|
||||
p_mem_ = other.p_mem_;
|
||||
other.p_mem_ = nullptr;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
~SimpleDeviceMem()
|
||||
{
|
||||
if(p_mem_)
|
||||
{
|
||||
(void)hipFree(p_mem_);
|
||||
}
|
||||
}
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
@@ -173,5 +203,90 @@ __global__ void strided_copy_kernel(const DataType* __restrict__ src,
|
||||
}
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Helper for parameter pack expansion (D tensors)
|
||||
template <typename ResultType, typename Op, typename DataType, std::size_t... Is>
|
||||
__device__ __forceinline__ void apply_multi_tensor_impl(ResultType& result,
|
||||
Op&& element_op,
|
||||
const DataType* const* tensor_ptrs,
|
||||
long_index_t element_offset,
|
||||
std::index_sequence<Is...>)
|
||||
{
|
||||
element_op(result, tensor_ptrs[Is][element_offset]...);
|
||||
}
|
||||
|
||||
// Generic helper for A and B tensors (works in all directions)
|
||||
template <index_t NumExtraTensors, typename DataType, typename ResultType, typename Op>
|
||||
__device__ __forceinline__ void apply_multi_tensor_elementwise_op(ResultType& result,
|
||||
Op&& element_op,
|
||||
const DataType* primary_ptr,
|
||||
const DataType* const* extra_ptrs,
|
||||
long_index_t extra_base_offset,
|
||||
long_index_t element_offset)
|
||||
{
|
||||
const DataType* tensor_ptrs[NumExtraTensors + 1];
|
||||
tensor_ptrs[0] = primary_ptr;
|
||||
|
||||
static_for<1, NumExtraTensors + 1, 1>{}(
|
||||
[&](auto i) { tensor_ptrs[i] = extra_ptrs[i - 1] + extra_base_offset; });
|
||||
|
||||
apply_multi_tensor_impl(result,
|
||||
element_op,
|
||||
tensor_ptrs,
|
||||
element_offset,
|
||||
std::make_index_sequence<NumExtraTensors + 1>{});
|
||||
}
|
||||
|
||||
// Helper for parameter pack expansion (D tensors)
|
||||
template <typename OutDataType, typename Op, std::size_t... Is>
|
||||
__device__ __forceinline__ void apply_d_tensor_impl(OutDataType& result_out,
|
||||
Op&& element_op,
|
||||
float computed_value,
|
||||
const float* d_values,
|
||||
std::index_sequence<Is...>)
|
||||
{
|
||||
float temp_out;
|
||||
element_op(temp_out, computed_value, d_values[Is]...);
|
||||
result_out = type_convert<OutDataType>(temp_out);
|
||||
}
|
||||
|
||||
// Specialized helper for D tensors with stride calculations and float conversion
|
||||
template <index_t NumDTensors, typename DDataType, typename OutDataType, typename Op>
|
||||
__device__ __forceinline__ void apply_d_tensor_elementwise_op(OutDataType& result_out,
|
||||
Op&& element_op,
|
||||
float computed_value,
|
||||
const DDataType* const* p_ds,
|
||||
const index_t* const* p_d_strides,
|
||||
index_t g,
|
||||
index_t n,
|
||||
index_t c_or_k,
|
||||
long_index_t spatial_linear_index)
|
||||
{
|
||||
if constexpr(NumDTensors == 0)
|
||||
{
|
||||
element_op(result_out, computed_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
float d_values[NumDTensors];
|
||||
|
||||
// Compute all D tensor indices and convert to float
|
||||
static_for<0, NumDTensors, 1>{}([&](auto i) {
|
||||
const long_index_t d_idx = g * p_d_strides[i][0] + n * p_d_strides[i][1] +
|
||||
c_or_k * p_d_strides[i][2] + spatial_linear_index;
|
||||
d_values[i] = type_convert<float>(p_ds[i][d_idx]);
|
||||
});
|
||||
|
||||
apply_d_tensor_impl(result_out,
|
||||
element_op,
|
||||
computed_value,
|
||||
d_values,
|
||||
std::make_index_sequence<NumDTensors>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace ref
|
||||
} // namespace ck
|
||||
|
||||
@@ -376,7 +376,7 @@ using device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
#if defined(CK_USE_GFX950)
|
||||
constexpr auto _k_per_block = 32;
|
||||
#else
|
||||
constexpr auto _k_per_block = 16;
|
||||
|
||||
@@ -147,7 +147,7 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
#if defined(CK_USE_GFX950)
|
||||
constexpr auto _k_per_block = 32;
|
||||
#else
|
||||
constexpr auto _k_per_block = 16;
|
||||
|
||||
@@ -47,7 +47,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::t
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
|
||||
@@ -40,7 +40,7 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::t
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
@@ -49,7 +49,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::t
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
|
||||
@@ -41,7 +41,7 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::t
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
@@ -52,7 +52,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::t
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
|
||||
@@ -44,7 +44,7 @@ using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = std::tupl
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
|
||||
@@ -40,9 +40,9 @@ using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = std::tupl
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | |
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
|
||||
@@ -41,7 +41,7 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
@@ -49,7 +49,7 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
|
||||
@@ -7,12 +7,17 @@
|
||||
|
||||
#include "../../experimental/builder/test/utils/conv_algorithm_type_utils.hpp"
|
||||
#include "grouped_convolution_signatures.hpp"
|
||||
#include "ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp"
|
||||
|
||||
#include "ck_tile/builder/testing/filter_extent.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd_reference.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd.hpp"
|
||||
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
|
||||
#include "ck_tile/builder/testing/conv/reference.hpp"
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
|
||||
// Temporary disable builder validate since we don't have deduced rtol, atol support
|
||||
#define ENABLE_BUILDER_VALIDATE 0
|
||||
|
||||
namespace ck_tile::builder::profiling {
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
@@ -113,25 +118,66 @@ run_grouped_conv_forward_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
auto reference = ckt::alloc_outputs(args);
|
||||
using ReferenceInstance =
|
||||
typename ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
auto ref_conv = ReferenceInstance{};
|
||||
ckt::run(ref_conv, args, inputs, reference.get());
|
||||
auto ref_conv = ReferenceInstance{};
|
||||
[[maybe_unused]] auto ref_result = ckt::run(ref_conv, args, inputs, reference.get());
|
||||
|
||||
#if ENABLE_BUILDER_VALIDATE == 0
|
||||
using DataType =
|
||||
std::conditional_t<SIGNATURE.data_type == ckb::DataType::FP32,
|
||||
float,
|
||||
std::conditional_t<SIGNATURE.data_type == ckb::DataType::FP16,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bfloat16_t>>;
|
||||
const auto conv_param = args.to_ck_tile_conv_param();
|
||||
|
||||
const std::size_t output_bytes_num = conv_param.template GetOutputByte<DataType>();
|
||||
std::vector<DataType> out(output_bytes_num / sizeof(DataType));
|
||||
std::vector<DataType> ref(output_bytes_num / sizeof(DataType));
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(&ref.data()[0], reference.get().output, output_bytes_num, hipMemcpyDeviceToHost));
|
||||
|
||||
const ck_tile::index_t GemmK = std::accumulate(conv_param.filter_spatial_lengths_.cbegin(),
|
||||
conv_param.filter_spatial_lengths_.cend(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>()) *
|
||||
conv_param.C_;
|
||||
float max_accumulated_value = *std::max_element(ref.begin(), ref.end());
|
||||
const auto rtol = ck_tile::get_relative_threshold<DataType, DataType, float>(GemmK);
|
||||
const auto atol =
|
||||
ck_tile::get_absolute_threshold<DataType, DataType, float>(max_accumulated_value, GemmK);
|
||||
#endif
|
||||
|
||||
[[maybe_unused]] auto run_alg = [&](auto&& run_alg_func) {
|
||||
std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf);
|
||||
if(is_supported)
|
||||
{
|
||||
best_avg_time = std::min(best_avg_time, avg_time);
|
||||
best_op_name = best_avg_time < avg_time ? best_op_name : op_name;
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms," << " " << op_name
|
||||
<< std::endl;
|
||||
|
||||
#if ENABLE_BUILDER_VALIDATE
|
||||
const auto errors = ckt::validate(args, outputs, reference.get()).get_errors();
|
||||
for(const auto& error : errors)
|
||||
{
|
||||
valid = false;
|
||||
std::cout << "Number of incorrect values: " << error.wrong_elements
|
||||
<< " Is all zero:" << error.is_all_zero() << std::endl;
|
||||
<< " Is all zero:" << error.is_all_zero()
|
||||
<< " max err: " << error.max_error << std::endl;
|
||||
}
|
||||
best_avg_time = std::min(best_avg_time, avg_time);
|
||||
best_op_name = best_avg_time < avg_time ? best_op_name : op_name;
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms,";
|
||||
#else
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(&out.data()[0], outputs.output, output_bytes_num, hipMemcpyDeviceToHost));
|
||||
valid = ck_tile::check_err(out, ref, "Error: Incorrect results!", rtol, atol);
|
||||
#endif
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol
|
||||
<< " Absolute error threshold: " << atol << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << " " << op_name << std::endl;
|
||||
}
|
||||
std::cout << " " << op_name << std::endl;
|
||||
};
|
||||
|
||||
if constexpr(SIGNATURE == SIGNATURE_NHWGC_FP16_FWD)
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include <tuple>
|
||||
|
||||
#include "../../experimental/builder/test/impl/conv_signature_types.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp"
|
||||
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
|
||||
|
||||
namespace ck_tile::builder::profiling {
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -129,7 +130,10 @@ bool profile_conv_bwd_data_impl(int do_verification,
|
||||
out_device_buf.ToDevice(output.mData.data());
|
||||
wei_device_buf.ToDevice(weight.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
// profile device Conv instances
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification == 1)
|
||||
{
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<NDimSpatial,
|
||||
InDataType,
|
||||
@@ -154,6 +158,27 @@ bool profile_conv_bwd_data_impl(int do_verification,
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
|
||||
// GPU reference (compute once, compare in kernel loop)
|
||||
Tensor<InDataType> gpu_ref_input(in_g_n_c_wis_desc);
|
||||
if(do_verification == 2)
|
||||
{
|
||||
DeviceMem gpu_ref_in_dev(sizeof(InDataType) *
|
||||
input_device_result.mDesc.GetElementSpaceSize());
|
||||
gpu_ref_in_dev.SetZero(); // bwd data needs zero initialization
|
||||
|
||||
ck::ref::naive_conv_bwd_data<InLayout, WeiLayout, OutLayout>(
|
||||
static_cast<InDataType*>(gpu_ref_in_dev.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
conv_param,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
gpu_ref_in_dev.FromDevice(gpu_ref_input.mData.data());
|
||||
}
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceConvBwdData<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
@@ -176,8 +201,6 @@ bool profile_conv_bwd_data_impl(int do_verification,
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
int num_kernel = 0;
|
||||
// profile device Conv instances
|
||||
bool pass = true;
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
@@ -235,7 +258,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
if(do_verification == 1)
|
||||
{
|
||||
in_device_buf.FromDevice(input_device_result.mData.data());
|
||||
|
||||
@@ -255,6 +278,31 @@ bool profile_conv_bwd_data_impl(int do_verification,
|
||||
show_data_nhwc_layout(input_host_result);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "out_device: ";
|
||||
show_data_nhwc_layout(input_device_result);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
else if(do_verification == 2)
|
||||
{
|
||||
in_device_buf.FromDevice(input_device_result.mData.data());
|
||||
|
||||
pass = pass & ck::utils::check_err(input_device_result, gpu_ref_input);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
std::cout << "in : ";
|
||||
show_data_nhwc_layout(output);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "wei: ";
|
||||
show_data_nhwc_layout(weight);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "out_gpu_ref : ";
|
||||
show_data_nhwc_layout(gpu_ref_input);
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "out_device: ";
|
||||
show_data_nhwc_layout(input_device_result);
|
||||
std::cout << std::endl;
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -107,8 +108,11 @@ bool profile_conv_fwd_impl(int do_verification,
|
||||
in_device_buf.ToDevice(input.mData.data());
|
||||
wei_device_buf.ToDevice(weight.mData.data());
|
||||
|
||||
// profile device op instances
|
||||
bool pass = true;
|
||||
|
||||
// run reference op
|
||||
if(do_verification)
|
||||
if(do_verification == 1)
|
||||
{
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
InDataType,
|
||||
@@ -135,6 +139,24 @@ bool profile_conv_fwd_impl(int do_verification,
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
// GPU reference (compute once, compare in kernel loop)
|
||||
Tensor<OutDataType> gpu_ref_output(out_g_n_k_wos_desc);
|
||||
if(do_verification == 2)
|
||||
{
|
||||
DeviceMem gpu_ref_out_dev(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize());
|
||||
|
||||
ck::ref::naive_conv_fwd<InLayout, WeiLayout, OutLayout>(
|
||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(gpu_ref_out_dev.GetDeviceBuffer()),
|
||||
conv_param,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
gpu_ref_out_dev.FromDevice(gpu_ref_output.mData.data());
|
||||
}
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceConvFwd<NDimSpatial,
|
||||
InLayout,
|
||||
@@ -158,8 +180,6 @@ bool profile_conv_fwd_impl(int do_verification,
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
int num_kernel = 0;
|
||||
// profile device op instances
|
||||
bool pass = true;
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
@@ -217,7 +237,7 @@ bool profile_conv_fwd_impl(int do_verification,
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
if(do_verification == 1)
|
||||
{
|
||||
out_device_buf.FromDevice(device_output.mData.data());
|
||||
|
||||
@@ -233,6 +253,23 @@ bool profile_conv_fwd_impl(int do_verification,
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
else if(do_verification == 2)
|
||||
{
|
||||
out_device_buf.FromDevice(device_output.mData.data());
|
||||
|
||||
pass = pass & ck::utils::check_err(device_output, gpu_ref_output);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "input : ", input.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "weight: ", weight.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "gpu_ref_output : ", gpu_ref_output.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "device_output: ", device_output.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -364,26 +364,39 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
using AccDataType =
|
||||
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
|
||||
|
||||
// Calculate number of accumulations accounting for split_k
|
||||
const int num_accums =
|
||||
static_cast<int>(output.GetElementSize() / conv_param.K_ / split_k_value);
|
||||
|
||||
// Additional tolerance for split_k accumulation if needed
|
||||
int total_accums = num_accums;
|
||||
if(split_k_value > 1)
|
||||
{
|
||||
total_accums = std::max(num_accums, static_cast<int>(split_k_value));
|
||||
}
|
||||
|
||||
// Perform GPU verification (max value computed internally on GPU)
|
||||
const index_t num_accums = output.GetElementSize() / conv_param.K_;
|
||||
const index_t num_accums_split_k = split_k_value;
|
||||
// Get maximum accumulated value from reference
|
||||
const std::size_t tensor_size =
|
||||
weight_device_result.mDesc.GetElementSpaceSize();
|
||||
max_accumulated_value =
|
||||
gpu_reduce_max<WeiDataType>(gpu_ref_wei_buf.GetDeviceBuffer(), tensor_size);
|
||||
// Calculate thresholds
|
||||
auto rtol =
|
||||
ck::utils::get_relative_threshold<ComputeType, WeiDataType, AccDataType>(
|
||||
num_accums / num_accums_split_k);
|
||||
auto atol =
|
||||
ck::utils::get_absolute_threshold<ComputeType, WeiDataType, AccDataType>(
|
||||
max_accumulated_value / num_accums_split_k,
|
||||
num_accums / num_accums_split_k);
|
||||
// Calculate error due to split_k accumulation
|
||||
auto rtol_split_k =
|
||||
ck::utils::get_relative_threshold<WeiDataType, WeiDataType, WeiDataType>(
|
||||
num_accums_split_k);
|
||||
auto atol_split_k =
|
||||
ck::utils::get_absolute_threshold<WeiDataType, WeiDataType, WeiDataType>(
|
||||
max_accumulated_value, num_accums_split_k);
|
||||
// Use higher threshold
|
||||
rtol = std::max(rtol, rtol_split_k);
|
||||
atol = std::max(atol, atol_split_k);
|
||||
|
||||
// Perform GPU verification
|
||||
auto gpu_result =
|
||||
ck::profiler::gpu_verify<WeiDataType, ComputeType, AccDataType>(
|
||||
wei_device_buf.GetDeviceBuffer(),
|
||||
gpu_ref_wei_buf.GetDeviceBuffer(),
|
||||
total_accums,
|
||||
tensor_size);
|
||||
ck::profiler::gpu_verify<WeiDataType>(wei_device_buf.GetDeviceBuffer(),
|
||||
gpu_ref_wei_buf.GetDeviceBuffer(),
|
||||
rtol,
|
||||
atol,
|
||||
tensor_size);
|
||||
|
||||
if(!gpu_result)
|
||||
{
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -156,8 +157,9 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
|
||||
bias_device_buf.ToDevice(bias.mData.data());
|
||||
|
||||
// run reference op
|
||||
if(do_verification)
|
||||
if(do_verification == 1)
|
||||
{
|
||||
// CPU reference
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
@@ -190,6 +192,75 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
else if(do_verification == 2)
|
||||
{
|
||||
// GPU reference
|
||||
std::vector<ck::index_t> d_lengths_vec(NDimSpatial + 3);
|
||||
std::vector<ck::index_t> d_strides_vec(NDimSpatial + 3);
|
||||
|
||||
d_lengths_vec[0] = conv_param.G_;
|
||||
d_lengths_vec[1] = conv_param.N_;
|
||||
d_lengths_vec[2] = conv_param.K_;
|
||||
for(ck::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
d_lengths_vec[3 + i] = static_cast<ck::index_t>(conv_param.output_spatial_lengths_[i]);
|
||||
}
|
||||
|
||||
if constexpr(BiasGK)
|
||||
{
|
||||
// For GK bias layout: G*K, zero strides for N and spatial dimensions
|
||||
d_strides_vec[0] = K;
|
||||
d_strides_vec[1] = 0;
|
||||
d_strides_vec[2] = 1;
|
||||
for(ck::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
d_strides_vec[3 + i] = 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Full GNKHW layout - same as output
|
||||
ck::ranges::copy(out_g_n_k_wos_desc.GetStrides(), d_strides_vec.begin());
|
||||
}
|
||||
|
||||
std::array<const OutDataType*, 1> d_ptrs = {
|
||||
reinterpret_cast<const OutDataType*>(bias_device_buf.GetDeviceBuffer())};
|
||||
std::array<std::vector<ck::index_t>, 1> d_lengths = {d_lengths_vec};
|
||||
std::array<std::vector<ck::index_t>, 1> d_strides = {d_strides_vec};
|
||||
|
||||
std::array<const InDataType*, 1> in_ptrs = {
|
||||
reinterpret_cast<const InDataType*>(in_device_buf.GetDeviceBuffer())};
|
||||
std::array<const WeiDataType*, 1> wei_ptrs = {
|
||||
reinterpret_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer())};
|
||||
|
||||
ck::ref::naive_conv_fwd_multi_abd<0,
|
||||
0,
|
||||
1,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
OutDataType>( // Explicitly specify TD = OutDataType
|
||||
in_ptrs,
|
||||
wei_ptrs,
|
||||
d_ptrs,
|
||||
reinterpret_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
conv_param,
|
||||
d_lengths,
|
||||
d_strides,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
out_device_buf.FromDevice(host_output.mData.data());
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_avg_time = 0;
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -129,8 +130,9 @@ bool profile_grouped_conv_fwd_bilinear_impl(
|
||||
wei_device_buf.ToDevice(weight.mData.data());
|
||||
d_device_buf.ToDevice(d_tensor.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
if(do_verification == 1)
|
||||
{
|
||||
// CPU reference
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
|
||||
NDimSpatial,
|
||||
InDataType,
|
||||
@@ -167,6 +169,61 @@ bool profile_grouped_conv_fwd_bilinear_impl(
|
||||
host_output(idx) = ck::type_convert<OutDataType>(out_val);
|
||||
});
|
||||
}
|
||||
else if(do_verification == 2)
|
||||
{
|
||||
// GPU reference
|
||||
std::vector<ck::index_t> d_lengths_vec(NDimSpatial + 3);
|
||||
std::vector<ck::index_t> d_strides_vec(NDimSpatial + 3);
|
||||
|
||||
d_lengths_vec[0] = conv_param.G_;
|
||||
d_lengths_vec[1] = conv_param.N_;
|
||||
d_lengths_vec[2] = conv_param.K_;
|
||||
for(ck::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
d_lengths_vec[3 + i] = static_cast<ck::index_t>(conv_param.output_spatial_lengths_[i]);
|
||||
}
|
||||
|
||||
// D tensor has same layout as output
|
||||
ck::ranges::copy(d_host_tensor_descriptor.GetStrides(), d_strides_vec.begin());
|
||||
|
||||
std::array<const DDataType*, 1> d_ptrs = {
|
||||
reinterpret_cast<const DDataType*>(d_device_buf.GetDeviceBuffer())};
|
||||
std::array<std::vector<ck::index_t>, 1> d_lengths = {d_lengths_vec};
|
||||
std::array<std::vector<ck::index_t>, 1> d_strides = {d_strides_vec};
|
||||
|
||||
std::array<const InDataType*, 1> in_ptrs = {
|
||||
reinterpret_cast<const InDataType*>(in_device_buf.GetDeviceBuffer())};
|
||||
std::array<const WeiDataType*, 1> wei_ptrs = {
|
||||
reinterpret_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer())};
|
||||
|
||||
ck::ref::naive_conv_fwd_multi_abd<0,
|
||||
0,
|
||||
1,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DDataType>( // Explicitly specify D tensor type
|
||||
in_ptrs,
|
||||
wei_ptrs,
|
||||
d_ptrs,
|
||||
reinterpret_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
conv_param,
|
||||
d_lengths,
|
||||
d_strides,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
bilinear_op);
|
||||
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
out_device_buf.FromDevice(host_output.mData.data());
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_avg_time = 0;
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "profiler/common.hpp"
|
||||
@@ -150,7 +151,7 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification,
|
||||
std::cout << "scale_out: " << scale_out << std::endl;
|
||||
|
||||
// run reference op
|
||||
if(do_verification)
|
||||
if(do_verification == 1)
|
||||
{
|
||||
|
||||
std::cout << "\nVerifying algorithm against reference convolution..." << std::endl;
|
||||
@@ -200,6 +201,57 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification,
|
||||
}
|
||||
});
|
||||
}
|
||||
else if(do_verification == 2)
|
||||
{
|
||||
// GPU reference
|
||||
// WORKAROUND: For int8_t with Scale, use CPU post-processing to match CPU reference
|
||||
// Pure GPU approach fails int8 test (see 2026-01-07-int8-scale-debugging.md)
|
||||
if constexpr(std::is_same_v<OutElementOp, ck::tensor_operation::element_wise::Scale> &&
|
||||
std::is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
// Compute conv to CShuffleDataType (float), then post-process on CPU
|
||||
DeviceMem gpu_ref_c_dev(sizeof(CShuffleDataType) * c.mDesc.GetElementSpaceSize());
|
||||
|
||||
ck::ref::naive_conv_fwd<InLayout, WeiLayout, OutLayout>(
|
||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CShuffleDataType*>(gpu_ref_c_dev.GetDeviceBuffer()),
|
||||
conv_param,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
PassThrough{});
|
||||
|
||||
ck::hip_check_error(hipDeviceSynchronize());
|
||||
|
||||
Tensor<CShuffleDataType> gpu_c(out_g_n_k_wos_desc);
|
||||
gpu_ref_c_dev.FromDevice(gpu_c.mData.data());
|
||||
|
||||
// Post-process on CPU to match CPU reference behavior
|
||||
host_output.ForEach([&](auto&, auto idx) {
|
||||
const auto conv_shuffle = ck::type_convert<CShuffleDataType>(gpu_c(idx));
|
||||
const auto conv_val = ck::type_convert<OutDataType>(conv_shuffle);
|
||||
out_element_op(host_output(idx), conv_val);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// Normal path for non-int8 or non-Scale cases
|
||||
DeviceMem gpu_ref_out_dev(sizeof(OutDataType) *
|
||||
device_output.mDesc.GetElementSpaceSize());
|
||||
|
||||
ck::ref::naive_conv_fwd<InLayout, WeiLayout, OutLayout>(
|
||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(gpu_ref_out_dev.GetDeviceBuffer()),
|
||||
conv_param,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
ck::hip_check_error(hipDeviceSynchronize());
|
||||
gpu_ref_out_dev.FromDevice(host_output.mData.data());
|
||||
}
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_avg_time = 0;
|
||||
@@ -239,7 +291,7 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification,
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
if(do_verification == 1)
|
||||
{
|
||||
out_device_buf.FromDevice(device_output.mData.data());
|
||||
|
||||
@@ -259,6 +311,27 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification,
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
else if(do_verification == 2)
|
||||
{
|
||||
out_device_buf.FromDevice(device_output.mData.data());
|
||||
|
||||
pass =
|
||||
pass & ck::utils::check_err(device_output,
|
||||
host_output,
|
||||
"Error: Device and GPU ref results do not match!",
|
||||
get_rtol<OutDataType>(),
|
||||
get_atol<OutDataType>());
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "input : ", input.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "weight: ", weight.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "gpu_ref_output : ", host_output.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "device_output: ", device_output.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp"
|
||||
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "profiler/grouped_convolution_forward_tile_algs.hpp"
|
||||
|
||||
|
||||
263
script/analyze_build/README.md
Normal file
263
script/analyze_build/README.md
Normal file
@@ -0,0 +1,263 @@
|
||||
# Build Trace Analysis
|
||||
|
||||
Simple to use, fast python tools for analyzing Clang `-ftime-trace` build performance data.
|
||||
|
||||
## Overview
|
||||
|
||||
We're kicking off a systematic effort to dramatically reduce CK and CK-Tile build times, [#3575](https://github.com/ROCm/composable_kernel/issues/3575). A key part of this work is improving our C++ metaprogramming to reduce the burden on the compiler.
|
||||
|
||||
In order to prioritize work and measure our progress, we need data on template instantiation. For single files, Clang's `-ftime-trace` build performance data is easy to analyze with the Perfetto UI. The problem we are solving here is how to analyze instantiation data across thousands of compilation units.
|
||||
|
||||
The python code in this directory provides helper functions to quickly load JSON files into pandas DataFrames that can be used for analysis in Jupyter notebooks.
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
script/analyze_build/
|
||||
├── trace_analysis/ # Core library
|
||||
│ ├── __init__.py # Main exports
|
||||
│ ├── parse_file.py # Fast parsing of JSON trace files
|
||||
│ ├── template_analysis.py # Template instantiation analysis
|
||||
│ ├── template_parser.py # Template name parsing utilities
|
||||
│ └── phase_breakdown.py # Compilation phase breakdown
|
||||
├── notebooks/ # Jupyter notebooks for analysis
|
||||
│ └── file_analysis_example.ipynb # Template analysis example
|
||||
├── requirements.txt # Python dependencies
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Python Requirements
|
||||
|
||||
See `requirements.txt` for the complete list of dependencies:
|
||||
* **pandas** - DataFrame manipulation and analysis
|
||||
* **orjson** - Fast JSON parsing for trace files
|
||||
* **plotly** - Interactive visualizations (sunburst, treemap)
|
||||
* **nbformat** - Jupyter notebook format support
|
||||
* **ipykernel** - Kernel for running notebooks in VSCode/Jupyter
|
||||
* **kaleido** - Static image export from Plotly charts
|
||||
* **jupyter** - Full Jupyter environment
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Setup
|
||||
|
||||
1. Create a virtual environment (recommended):
|
||||
```bash
|
||||
cd script/analyze_build
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
||||
```
|
||||
|
||||
2. Install dependencies:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
3. Install VSCode extensions if you want to run notebooks in VSCode:
|
||||
* Jupyter
|
||||
* Data Wrangler (interact with Pandas DataFrames)
|
||||
|
||||
### Analyzing a Single File
|
||||
|
||||
Use the `parse_file` function to load a `-ftime-trace` JSON file into a Pandas DataFrame:
|
||||
|
||||
```python
|
||||
from trace_analysis import parse_file
|
||||
|
||||
# Parse the trace file
|
||||
df = parse_file('path/to/trace.json')
|
||||
|
||||
# View basic info
|
||||
print(f"Total events: {len(df)}")
|
||||
print(df.columns)
|
||||
|
||||
# Analyze duration statistics
|
||||
print(df['dur'].describe())
|
||||
```
|
||||
|
||||
### Extracting Compilation Metadata
|
||||
|
||||
Get high-level metadata about the compilation:
|
||||
|
||||
```python
|
||||
from trace_analysis import get_metadata
|
||||
|
||||
# Extract metadata from trace file
|
||||
metadata = get_metadata('trace.json')
|
||||
|
||||
print(f"Source file: {metadata['source_file']}")
|
||||
print(f"Compilation time: {metadata['total_wall_time_s']:.2f}s")
|
||||
print(f"Started: {metadata['wall_start_datetime']}")
|
||||
print(f"Ended: {metadata['wall_end_datetime']}")
|
||||
```
|
||||
|
||||
The metadata includes:
|
||||
- `source_file`: Main .cpp/.c file being compiled
|
||||
- `time_granularity`: Time unit used ("microseconds")
|
||||
- `beginning_of_time`: Epoch timestamp in microseconds
|
||||
- `wall_start_time`: Wall clock start (microseconds since epoch)
|
||||
- `wall_end_time`: Wall clock end (microseconds since epoch)
|
||||
- `wall_start_datetime`: Human-readable start time
|
||||
- `wall_end_datetime`: Human-readable end time
|
||||
- `total_wall_time_us`: Total compilation time in microseconds
|
||||
- `total_wall_time_s`: Total compilation time in seconds
|
||||
|
||||
### Template Instantiation Analysis
|
||||
|
||||
The module includes specialized functions for analyzing C++ template instantiation costs:
|
||||
|
||||
```python
|
||||
from trace_analysis import (
|
||||
parse_file,
|
||||
get_template_instantiation_events,
|
||||
get_phase_breakdown,
|
||||
)
|
||||
|
||||
df = parse_file('trace.json')
|
||||
|
||||
# Get all template instantiation events with parsed template information
|
||||
template_events = get_template_instantiation_events(df)
|
||||
|
||||
# The returned DataFrame includes parsed columns:
|
||||
# - namespace: Top-level namespace (e.g., 'std', 'ck')
|
||||
# - template_name: Template name without parameters
|
||||
# - full_qualified_name: Full namespace::template_name
|
||||
# - param_count: Number of template parameters
|
||||
# - is_ck_type: Boolean indicating CK library types
|
||||
# - is_nested: Boolean indicating nested templates
|
||||
|
||||
# Find slowest template instantiations
|
||||
top_templates = template_events.nlargest(20, 'dur')
|
||||
print(top_templates[['template_name', 'namespace', 'param_count', 'dur']])
|
||||
|
||||
# Analyze by namespace
|
||||
namespace_summary = template_events.groupby('namespace').agg({
|
||||
'dur': ['count', 'sum', 'mean']
|
||||
})
|
||||
print(namespace_summary)
|
||||
```
|
||||
|
||||
### Compilation Phase Breakdown
|
||||
|
||||
Analyze how compilation time is distributed across different phases:
|
||||
|
||||
```python
|
||||
from trace_analysis import get_phase_breakdown, PhaseBreakdown
|
||||
|
||||
df = parse_file('trace.json')
|
||||
|
||||
# Get hierarchical phase breakdown
|
||||
breakdown = get_phase_breakdown(df)
|
||||
|
||||
# Display in Jupyter (automatic rich HTML display)
|
||||
display(breakdown)
|
||||
|
||||
# Print text representation
|
||||
print(breakdown)
|
||||
|
||||
# Access the underlying DataFrame
|
||||
print(breakdown.df)
|
||||
|
||||
# Convert to plotly format for visualization
|
||||
import plotly.express as px
|
||||
data = breakdown.to_plotly()
|
||||
fig = px.sunburst(**data)
|
||||
fig.show()
|
||||
```
|
||||
|
||||
The `PhaseBreakdown` class provides:
|
||||
- Hierarchical breakdown of compilation phases
|
||||
- Automatic calculation of "Other" residual time at each level
|
||||
- Validation that children don't exceed parent durations
|
||||
- Multiple output formats (text, DataFrame, Plotly)
|
||||
|
||||
## DataFrame Schema
|
||||
|
||||
The parsed DataFrame contains the following columns from the `-ftime-trace` format:
|
||||
|
||||
- `name`: Event name (function, template instantiation, etc.)
|
||||
- `ph`: Phase character ('X' for complete, 'B' for begin, 'E' for end, 'i' for instant)
|
||||
- `ts`: Timestamp in microseconds
|
||||
- `dur`: Duration in microseconds (for complete events)
|
||||
- `pid`: Process ID
|
||||
- `tid`: Thread ID
|
||||
- `arg_*`: Flattened arguments from the event's `args` field
|
||||
|
||||
### Template Event Columns
|
||||
|
||||
When using `get_template_instantiation_events()`, additional parsed columns are included:
|
||||
|
||||
- `namespace`: Top-level namespace extracted from the template name
|
||||
- `template_name`: Template name without namespace or parameters
|
||||
- `full_qualified_name`: Complete namespace::template_name
|
||||
- `param_count`: Number of template parameters
|
||||
- `is_ck_type`: Boolean flag for CK library types (namespace starts with 'ck')
|
||||
- `is_nested`: Boolean flag indicating nested template instantiations
|
||||
|
||||
## Use in Jupyter Notebooks
|
||||
|
||||
The module is designed to work seamlessly in Jupyter notebooks. See `notebooks/file_analysis_example.ipynb` for a complete example workflow that demonstrates:
|
||||
|
||||
- Loading and parsing trace files
|
||||
- Extracting compilation metadata
|
||||
- Analyzing phase breakdown with visualizations
|
||||
- Template instantiation analysis with parsed columns
|
||||
- Filtering and grouping by namespace
|
||||
- Identifying CK-specific template costs
|
||||
|
||||
To use in a notebook:
|
||||
|
||||
```python
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add trace_analysis to path
|
||||
sys.path.insert(0, str(Path.cwd().parent))
|
||||
|
||||
from trace_analysis import (
|
||||
parse_file,
|
||||
get_metadata,
|
||||
get_template_instantiation_events,
|
||||
get_phase_breakdown,
|
||||
)
|
||||
|
||||
# Load and analyze
|
||||
df = parse_file('path/to/trace.json')
|
||||
breakdown = get_phase_breakdown(df)
|
||||
templates = get_template_instantiation_events(df)
|
||||
|
||||
# Visualize
|
||||
import plotly.express as px
|
||||
fig = px.sunburst(**breakdown.to_plotly())
|
||||
fig.show()
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Core Functions
|
||||
|
||||
- `parse_file(filepath)`: Parse a `-ftime-trace` JSON file into a pandas DataFrame
|
||||
- `get_metadata(filepath_or_df)`: Extract compilation metadata from trace file or DataFrame
|
||||
|
||||
### Template Analysis
|
||||
|
||||
- `get_template_instantiation_events(df)`: Filter to template instantiation events with parsed template information
|
||||
|
||||
### Phase Breakdown
|
||||
|
||||
- `get_phase_breakdown(df)`: Generate hierarchical compilation phase breakdown
|
||||
- `PhaseBreakdown`: Class representing phase breakdown with multiple output formats
|
||||
|
||||
## Contributing
|
||||
|
||||
This is an experimental project for analyzing and improving C++ metaprogramming build times. Contributions are welcome! When adding new analysis functions:
|
||||
|
||||
1. Add the function to the appropriate module in `trace_analysis/`
|
||||
2. Export it in `__init__.py`
|
||||
3. Update this README with usage examples
|
||||
4. Consider adding a notebook example if the feature is substantial
|
||||
|
||||
## License
|
||||
|
||||
Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
SPDX-License-Identifier: MIT
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user