mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
merge down from the public repo
This commit is contained in:
12
.github/dependabot.yml
vendored
Normal file
12
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
# To get started with Dependabot version updates, you'll need to specify which
|
||||
# package ecosystems to update and where the package manifests are located.
|
||||
# Please see the documentation for all configuration options:
|
||||
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
|
||||
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "pip" # See documentation for possible values
|
||||
directory: "/docs/.sphinx" # Location of package manifests
|
||||
open-pull-requests-limit: 10
|
||||
schedule:
|
||||
interval: "daily"
|
||||
41
Dockerfile
41
Dockerfile
@@ -1,30 +1,34 @@
|
||||
FROM ubuntu:20.04
|
||||
|
||||
ARG ROCMVERSION=5.3
|
||||
ARG compiler_version="release"
|
||||
ARG ROCMVERSION=5.6
|
||||
ARG compiler_version=""
|
||||
ARG compiler_commit=""
|
||||
|
||||
RUN set -xe
|
||||
|
||||
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
|
||||
RUN useradd -rm -d /home/jenkins -s /bin/bash -u 1004 jenkins
|
||||
RUN useradd -rm -d /home/manitera -s /bin/bash -u 1002 manitera
|
||||
# Add rocm repository
|
||||
RUN apt-get update
|
||||
RUN apt-get install -y wget gnupg
|
||||
RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
|
||||
RUN sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list"
|
||||
RUN apt-get install -y wget gnupg curl
|
||||
RUN --mount=type=ssh if [ "$ROCMVERSION" != "5.6"]; then \
|
||||
wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \
|
||||
sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list"; \
|
||||
else sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amd-nonfree-radeon_20.04-1_all.deb" && \
|
||||
apt update && apt-get install -y ./amd-nonfree-radeon_20.04-1_all.deb && \
|
||||
amdgpu-repo --amdgpu-build=1567752 --rocm-build=compute-rocm-dkms-no-npi-hipclang/11914 && \
|
||||
DEBIAN_FRONTEND=noninteractive amdgpu-install -y --usecase=rocm ; \
|
||||
fi
|
||||
RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add -
|
||||
RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list"
|
||||
RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg
|
||||
|
||||
# Install dependencies
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
|
||||
apt-utils \
|
||||
build-essential \
|
||||
ccache \
|
||||
cmake-data \
|
||||
cmake \
|
||||
curl \
|
||||
git \
|
||||
hip-rocclr \
|
||||
jq \
|
||||
@@ -45,6 +49,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
|
||||
rocm-device-libs \
|
||||
rocm-cmake \
|
||||
vim \
|
||||
nano \
|
||||
zlib1g-dev \
|
||||
openssh-server \
|
||||
clang-format-10 \
|
||||
@@ -52,6 +57,17 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
#Install latest version of cmake
|
||||
RUN apt purge --auto-remove -y cmake
|
||||
RUN apt update
|
||||
RUN apt install -y software-properties-common lsb-release
|
||||
RUN apt clean all
|
||||
RUN wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | gpg --dearmor - | tee /etc/apt/trusted.gpg.d/kitware.gpg >/dev/null
|
||||
RUN apt-add-repository "deb https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main"
|
||||
RUN apt install -y kitware-archive-keyring
|
||||
RUN rm /etc/apt/trusted.gpg.d/kitware.gpg
|
||||
RUN apt install -y cmake
|
||||
|
||||
# Setup ubsan environment to printstacktrace
|
||||
RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer
|
||||
ENV UBSAN_OPTIONS=print_stacktrace=1
|
||||
@@ -87,12 +103,7 @@ ENV compiler_commit=$compiler_commit
|
||||
RUN sh -c "echo compiler version = '$compiler_version'"
|
||||
RUN sh -c "echo compiler commit = '$compiler_commit'"
|
||||
|
||||
RUN --mount=type=ssh if [ "$compiler_version" = "amd-stg-open" ]; then \
|
||||
sed -i '/$HIP_CLANG_TARGET = chomp($HIP_CLANG_TARGET);/c\ chomp($HIP_CLANG_TARGET);' /opt/rocm/hip/bin/hipcc.pl && \
|
||||
sed -i '/$HIP_CLANG_TARGET = chomp($HIP_CLANG_TARGET);/c\ chomp($HIP_CLANG_TARGET);' /opt/rocm/bin/hipcc.pl; \
|
||||
fi
|
||||
|
||||
RUN --mount=type=ssh if [ "$compiler_version" != "release" ] && [ "$compiler_commit" = "" ]; then \
|
||||
RUN --mount=type=ssh if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" = "" ]; then \
|
||||
git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \
|
||||
cd llvm-project && mkdir build && cd build && \
|
||||
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \
|
||||
@@ -100,7 +111,7 @@ RUN --mount=type=ssh if [ "$compiler_version" != "release" ] && [ "$compiler_com
|
||||
else echo "using the release compiler"; \
|
||||
fi
|
||||
|
||||
RUN --mount=type=ssh if [ "$compiler_version" != "release" ] && [ "$compiler_commit" != "" ]; then \
|
||||
RUN --mount=type=ssh if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" != "" ]; then \
|
||||
git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \
|
||||
cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
|
||||
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \
|
||||
|
||||
53
Jenkinsfile
vendored
53
Jenkinsfile
vendored
@@ -19,12 +19,33 @@ def runShell(String command){
|
||||
|
||||
def getDockerImageName(){
|
||||
def img
|
||||
if (params.COMPILER_COMMIT == ""){
|
||||
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}"
|
||||
if (params.ROCMVERSION != "5.5" && params.ROCMVERSION != "5.6"){
|
||||
if (params.COMPILER_VERSION == "") {
|
||||
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}"
|
||||
}
|
||||
else{
|
||||
if (params.COMPILER_COMMIT == ""){
|
||||
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}"
|
||||
}
|
||||
else{
|
||||
def commit = "${params.COMPILER_COMMIT}"[0..6]
|
||||
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}"
|
||||
}
|
||||
}
|
||||
}
|
||||
else{
|
||||
def commit = "${params.COMPILER_COMMIT}"[0..6]
|
||||
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}"
|
||||
if (params.COMPILER_VERSION == "") {
|
||||
img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}"
|
||||
}
|
||||
else{
|
||||
if (params.COMPILER_COMMIT == ""){
|
||||
img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}"
|
||||
}
|
||||
else{
|
||||
def commit = "${params.COMPILER_COMMIT}"[0..6]
|
||||
img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}"
|
||||
}
|
||||
}
|
||||
}
|
||||
return img
|
||||
}
|
||||
@@ -49,11 +70,11 @@ def build_compiler(){
|
||||
compiler = '/opt/rocm/bin/hipcc'
|
||||
}
|
||||
else{
|
||||
if (params.COMPILER_VERSION == "release"){
|
||||
compiler = "/opt/rocm/llvm/bin/clang++"
|
||||
if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){
|
||||
compiler = "/llvm-project/build/bin/clang++"
|
||||
}
|
||||
else{
|
||||
compiler = "/llvm-project/build/bin/clang++"
|
||||
compiler = "/opt/rocm/llvm/bin/clang++"
|
||||
}
|
||||
}
|
||||
return compiler
|
||||
@@ -232,7 +253,7 @@ def buildHipClangJob(Map conf=[:]){
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
|
||||
}
|
||||
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
|
||||
if (params.COMPILER_VERSION != "release"){
|
||||
if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){
|
||||
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
|
||||
}
|
||||
|
||||
@@ -287,7 +308,7 @@ def runCKProfiler(Map conf=[:]){
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
|
||||
}
|
||||
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
|
||||
if (params.COMPILER_VERSION != "release"){
|
||||
if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){
|
||||
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
|
||||
}
|
||||
|
||||
@@ -420,7 +441,7 @@ def Build_CK(Map conf=[:]){
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
|
||||
}
|
||||
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
|
||||
if (params.COMPILER_VERSION != "release"){
|
||||
if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){
|
||||
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
|
||||
}
|
||||
|
||||
@@ -586,16 +607,16 @@ pipeline {
|
||||
description: "Force building docker image (default: false), set to true if docker image needs to be updated.")
|
||||
string(
|
||||
name: 'ROCMVERSION',
|
||||
defaultValue: '5.4.3',
|
||||
description: 'Specify which ROCM version to use: 5.4.3 (default).')
|
||||
defaultValue: '5.6',
|
||||
description: 'Specify which ROCM version to use: 5.6 (default).')
|
||||
string(
|
||||
name: 'COMPILER_VERSION',
|
||||
defaultValue: 'amd-stg-open',
|
||||
description: 'Specify which version of compiler to use: ck-9110, release, or amd-stg-open (default).')
|
||||
defaultValue: '',
|
||||
description: 'Specify which version of compiler to use: release, amd-stg-open, or leave blank (default).')
|
||||
string(
|
||||
name: 'COMPILER_COMMIT',
|
||||
defaultValue: '5541927df00eabd6a110180170eca7785d436ee3',
|
||||
description: 'Specify which commit of compiler branch to use: leave empty to use the latest commit, or use 5541927df00eabd6a110180170eca7785d436ee3 (default) commit of amd-stg-open branch.')
|
||||
defaultValue: '',
|
||||
description: 'Specify which commit of compiler branch to use: leave blank to use the latest commit, or use 5541927df00eabd6a110180170eca7785d436ee3 (default) commit of amd-stg-open branch.')
|
||||
string(
|
||||
name: 'BUILD_COMPILER',
|
||||
defaultValue: 'hipcc',
|
||||
|
||||
@@ -73,7 +73,7 @@ int main(int argc, char* argv[])
|
||||
SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C);
|
||||
SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C);
|
||||
SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C);
|
||||
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C);
|
||||
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K);
|
||||
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K);
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
|
||||
@@ -203,4 +203,4 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ int main(int argc, char* argv[])
|
||||
SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C);
|
||||
SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C);
|
||||
SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C);
|
||||
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C);
|
||||
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K);
|
||||
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K);
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
|
||||
@@ -206,4 +206,4 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C);
|
||||
SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C);
|
||||
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C);
|
||||
SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K);
|
||||
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K);
|
||||
|
||||
using DeviceOp =
|
||||
@@ -196,4 +196,4 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
2
client_example/18_groupnorm/CMakeLists.txt
Normal file
2
client_example/18_groupnorm/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_executable(client_groupnorm_swish groupnorm_swish.cpp)
|
||||
target_link_libraries(client_groupnorm_swish PRIVATE composable_kernel::device_operations)
|
||||
169
client_example/18_groupnorm/groupnorm_swish.cpp
Normal file
169
client_example/18_groupnorm/groupnorm_swish.cpp
Normal file
@@ -0,0 +1,169 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iomanip>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"
|
||||
|
||||
using XDataType = ck::half_t;
|
||||
using GammaDataType = float;
|
||||
using BetaDataType = float;
|
||||
using YDataType = ck::half_t;
|
||||
using ComputeDataType = float;
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
|
||||
constexpr int Rank = 5;
|
||||
constexpr int NumReduceDim = 3;
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ck::index_t N = 32;
|
||||
ck::index_t H = 16;
|
||||
ck::index_t W = 16;
|
||||
ck::index_t G = 64;
|
||||
ck::index_t C = 128;
|
||||
|
||||
std::size_t xy_size = N * H * W * G * C;
|
||||
std::size_t gamma_beta_size = G * C;
|
||||
|
||||
std::vector<ck::index_t> xy_strides = {H * W * G * C, W * G * C, G * C, C, 1};
|
||||
std::vector<ck::index_t> gamma_beta_strides = {0, 0, 0, C, 1};
|
||||
|
||||
SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size);
|
||||
SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size);
|
||||
SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size);
|
||||
SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size);
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
Swish,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_op_name;
|
||||
bool found = false;
|
||||
int best_op_id = -1;
|
||||
float best_ave_time = std::numeric_limits<float>::max();
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
for(int i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
|
||||
xy_strides, // xStrides
|
||||
gamma_beta_strides, // gammaStrides
|
||||
gamma_beta_strides, // betaStrides
|
||||
xy_strides, // yStrides
|
||||
{1, 2, 4}, // reduceDims
|
||||
1e-6,
|
||||
x_device_buf.GetDeviceBuffer(),
|
||||
gamma_device_buf.GetDeviceBuffer(),
|
||||
beta_device_buf.GetDeviceBuffer(),
|
||||
y_device_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
Swish{});
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
std::size_t num_byte =
|
||||
sizeof(XDataType) * xy_size + sizeof(GammaDataType) * gamma_beta_size +
|
||||
sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
|
||||
<< op_name << std::endl;
|
||||
|
||||
if(ave_time < best_ave_time)
|
||||
{
|
||||
found = true;
|
||||
best_op_id = i;
|
||||
best_op_name = op_name;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
|
||||
<< best_op_name << std::endl;
|
||||
|
||||
// run the best intance
|
||||
{
|
||||
auto& op_ptr = op_ptrs[best_op_id];
|
||||
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
|
||||
<< std::endl;
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
|
||||
xy_strides, // xStrides
|
||||
gamma_beta_strides, // gammaStrides
|
||||
gamma_beta_strides, // betaStrides
|
||||
xy_strides, // yStrides
|
||||
{1, 2, 4}, // reduceDims
|
||||
1e-6,
|
||||
x_device_buf.GetDeviceBuffer(),
|
||||
gamma_device_buf.GetDeviceBuffer(),
|
||||
beta_device_buf.GetDeviceBuffer(),
|
||||
y_device_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
Swish{});
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
}
|
||||
|
||||
std::cout << "Done" << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -21,6 +21,7 @@ list(APPEND GTEST_CMAKE_CXX_FLAGS
|
||||
-Wno-comma
|
||||
-Wno-old-style-cast
|
||||
-Wno-deprecated
|
||||
-Wno-unsafe-buffer-usage
|
||||
)
|
||||
message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}")
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
git+https://github.com/RadeonOpenCompute/rocm-docs-core.git
|
||||
rocm-docs-core==0.2.0
|
||||
sphinxcontrib-bibtex==2.5.0
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
# This file is autogenerated by pip-compile with Python 3.10
|
||||
# by the following command:
|
||||
#
|
||||
# pip-compile requirements.in
|
||||
# pip-compile .sphinx/requirements.in
|
||||
#
|
||||
accessible-pygments==0.0.4
|
||||
accessible-pygments==0.0.3
|
||||
# via pydata-sphinx-theme
|
||||
alabaster==0.7.13
|
||||
# via sphinx
|
||||
@@ -20,7 +20,7 @@ babel==2.12.1
|
||||
# sphinx
|
||||
backcall==0.2.0
|
||||
# via ipython
|
||||
beautifulsoup4==4.12.0
|
||||
beautifulsoup4==4.11.2
|
||||
# via pydata-sphinx-theme
|
||||
breathe==4.34.0
|
||||
# via rocm-docs-core
|
||||
@@ -34,7 +34,7 @@ click==8.1.3
|
||||
# via
|
||||
# jupyter-cache
|
||||
# sphinx-external-toc
|
||||
comm==0.1.3
|
||||
comm==0.1.2
|
||||
# via ipykernel
|
||||
debugpy==1.6.6
|
||||
# via ipykernel
|
||||
@@ -65,13 +65,11 @@ idna==3.4
|
||||
# via requests
|
||||
imagesize==1.4.1
|
||||
# via sphinx
|
||||
importlib-metadata==6.1.0
|
||||
importlib-metadata==6.0.0
|
||||
# via
|
||||
# jupyter-cache
|
||||
# myst-nb
|
||||
importlib-resources==5.10.4
|
||||
# via rocm-docs-core
|
||||
ipykernel==6.22.0
|
||||
ipykernel==6.21.3
|
||||
# via myst-nb
|
||||
ipython==8.11.0
|
||||
# via
|
||||
@@ -87,7 +85,7 @@ jsonschema==4.17.3
|
||||
# via nbformat
|
||||
jupyter-cache==0.5.0
|
||||
# via myst-nb
|
||||
jupyter-client==8.1.0
|
||||
jupyter-client==8.0.3
|
||||
# via
|
||||
# ipykernel
|
||||
# nbclient
|
||||
@@ -124,7 +122,7 @@ nbclient==0.5.13
|
||||
# via
|
||||
# jupyter-cache
|
||||
# myst-nb
|
||||
nbformat==5.8.0
|
||||
nbformat==5.7.3
|
||||
# via
|
||||
# jupyter-cache
|
||||
# myst-nb
|
||||
@@ -187,7 +185,7 @@ pyyaml==6.0
|
||||
# myst-parser
|
||||
# pybtex
|
||||
# sphinx-external-toc
|
||||
pyzmq==25.0.2
|
||||
pyzmq==25.0.1
|
||||
# via
|
||||
# ipykernel
|
||||
# jupyter-client
|
||||
@@ -195,8 +193,8 @@ requests==2.28.2
|
||||
# via
|
||||
# pygithub
|
||||
# sphinx
|
||||
rocm-docs-core @ git+https://github.com/RadeonOpenCompute/rocm-docs-core.git
|
||||
# via -r requirements.in
|
||||
rocm-docs-core==0.2.0
|
||||
# via -r .sphinx/requirements.in
|
||||
six==1.16.0
|
||||
# via
|
||||
# asttokens
|
||||
@@ -235,9 +233,7 @@ sphinx-notfound-page==0.8.3
|
||||
sphinxcontrib-applehelp==1.0.4
|
||||
# via sphinx
|
||||
sphinxcontrib-bibtex==2.5.0
|
||||
# via
|
||||
# -r requirements.in
|
||||
# rocm-docs-core
|
||||
# via -r .sphinx/requirements.in
|
||||
sphinxcontrib-devhelp==1.0.2
|
||||
# via sphinx
|
||||
sphinxcontrib-htmlhelp==2.0.1
|
||||
@@ -248,7 +244,7 @@ sphinxcontrib-qthelp==1.0.3
|
||||
# via sphinx
|
||||
sphinxcontrib-serializinghtml==1.1.5
|
||||
# via sphinx
|
||||
sqlalchemy==1.4.47
|
||||
sqlalchemy==1.4.46
|
||||
# via jupyter-cache
|
||||
stack-data==0.6.2
|
||||
# via ipython
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
add_example_executable(example_groupnorm_sigmoid_fp16 groupnorm_sigmoid_fp16.cpp)
|
||||
add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp)
|
||||
add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp)
|
||||
|
||||
23
example/42_groupnorm/common.hpp
Normal file
23
example/42_groupnorm/common.hpp
Normal file
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <getopt.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/reduction_enums.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
|
||||
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp"
|
||||
56
example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp
Normal file
56
example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp
Normal file
@@ -0,0 +1,56 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
constexpr int Rank = 5;
|
||||
constexpr int NumReduceDim = 3;
|
||||
|
||||
using XDataType = ck::half_t;
|
||||
using GammaDataType = ck::half_t;
|
||||
using BetaDataType = ck::half_t;
|
||||
using YDataType = ck::half_t;
|
||||
using ComputeDataType = float;
|
||||
|
||||
struct YElementOp
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(ck::is_same<T, float>::value || ck::is_same<T, double>::value ||
|
||||
ck::is_same<T, ck::half_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
T a;
|
||||
|
||||
ck::tensor_operation::element_wise::Sigmoid{}(a, x);
|
||||
|
||||
y = x * a;
|
||||
};
|
||||
};
|
||||
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
YElementOp,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
1024, // BlockSize
|
||||
1, // ClusterM
|
||||
1024, // ClusterK
|
||||
1, // SliceM
|
||||
32, // SliceK
|
||||
1, // SrcVecDim (0=M, 1=K)
|
||||
2, // SrcScalarPerVector
|
||||
1, // GammaVecDim (0=M, 1=K)
|
||||
2, // GammaScalarPerVector
|
||||
1, // BetaVecDim (0=M, 1=K)
|
||||
2, // BetaScalarPerVector
|
||||
2>; // OutScalarPerVector
|
||||
|
||||
#include "run_groupnorm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); }
|
||||
40
example/42_groupnorm/groupnorm_swish_fp16.cpp
Normal file
40
example/42_groupnorm/groupnorm_swish_fp16.cpp
Normal file
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
constexpr int Rank = 5;
|
||||
constexpr int NumReduceDim = 3;
|
||||
|
||||
using XDataType = ck::half_t;
|
||||
using GammaDataType = ck::half_t;
|
||||
using BetaDataType = ck::half_t;
|
||||
using YDataType = ck::half_t;
|
||||
using ComputeDataType = float;
|
||||
using YElementOp = ck::tensor_operation::element_wise::Swish;
|
||||
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
YElementOp,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
1024, // BlockSize
|
||||
1, // ClusterM
|
||||
1024, // ClusterK
|
||||
1, // SliceM
|
||||
32, // SliceK
|
||||
1, // SrcVecDim (0=M, 1=K)
|
||||
2, // SrcScalarPerVector
|
||||
1, // GammaVecDim (0=M, 1=K)
|
||||
2, // GammaScalarPerVector
|
||||
1, // BetaVecDim (0=M, 1=K)
|
||||
2, // BetaScalarPerVector
|
||||
2>; // OutScalarPerVector
|
||||
|
||||
#include "run_groupnorm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); }
|
||||
@@ -1,80 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <getopt.h>
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/reduction_enums.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
|
||||
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp"
|
||||
|
||||
constexpr int Rank = 5;
|
||||
constexpr int NumReduceDim = 3;
|
||||
|
||||
using XDataType = ck::half_t;
|
||||
using GammaDataType = ck::half_t;
|
||||
using BetaDataType = ck::half_t;
|
||||
using YDataType = ck::half_t;
|
||||
using ComputeDataType = float;
|
||||
|
||||
struct YElementOp
|
||||
int run_groupnorm_example(int argc, char* argv[])
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(ck::is_same<T, float>::value || ck::is_same<T, double>::value ||
|
||||
ck::is_same<T, ck::half_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
T a;
|
||||
|
||||
ck::tensor_operation::element_wise::Sigmoid{}(a, x);
|
||||
|
||||
y = x * a;
|
||||
};
|
||||
};
|
||||
|
||||
using DeviceInstance =
|
||||
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
YElementOp,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
1024, // BlockSize
|
||||
1, // ClusterM
|
||||
1024, // ClusterK
|
||||
1, // SliceM
|
||||
32, // SliceK
|
||||
1, // SrcVecDim (0=M, 1=K)
|
||||
2, // SrcScalarPerVector
|
||||
1, // GammaVecDim (0=M, 1=K)
|
||||
2, // GammaScalarPerVector
|
||||
1, // BetaVecDim (0=M, 1=K)
|
||||
2, // BetaScalarPerVector
|
||||
2>; // OutScalarPerVector
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ck::index_t N = 2;
|
||||
ck::index_t H = 32;
|
||||
ck::index_t W = 32;
|
||||
ck::index_t G = 32;
|
||||
ck::index_t C = 30;
|
||||
ck::index_t N = 32;
|
||||
ck::index_t H = 16;
|
||||
ck::index_t W = 16;
|
||||
ck::index_t G = 64;
|
||||
ck::index_t C = 128;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
@@ -172,6 +172,11 @@
|
||||
// flag to enable (1) or disable (0) the debugging output in some kernels
|
||||
#define DEBUG_LOG 0
|
||||
|
||||
// denorm test fix, required to work around dissue
|
||||
#ifndef CK_WORKAROUND_DENORM_FIX
|
||||
#define CK_WORKAROUND_DENORM_FIX 0
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum struct InMemoryDataOperationEnum
|
||||
|
||||
@@ -73,18 +73,157 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static auto
|
||||
MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad)
|
||||
{
|
||||
assert(KPad % (K1 * KBatch) == 0);
|
||||
|
||||
const index_t K0 = KPad / (K1 * KBatch);
|
||||
|
||||
const auto a_grid_desc_m_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_kpad,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
|
||||
make_right_pad_transform(M, PadM)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_kpad,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad)
|
||||
{
|
||||
assert(KPad % (K1 * KBatch) == 0);
|
||||
|
||||
const index_t K0 = KPad / (K1 * KBatch);
|
||||
|
||||
const auto b_grid_desc_k_n = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
|
||||
b_grid_desc_k_n,
|
||||
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_kpad_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
|
||||
make_right_pad_transform(N, PadN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_kpad_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
|
||||
{
|
||||
const auto c_grid_desc_m_n = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
|
||||
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto GetKPad(index_t K, index_t KBatch)
|
||||
{
|
||||
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
|
||||
const index_t KPad = KBatch * K0 * K1;
|
||||
return KPad;
|
||||
}
|
||||
|
||||
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1));
|
||||
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
@@ -114,64 +253,236 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXDL,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
// GridwiseGemm
|
||||
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXDL,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
|
||||
|
||||
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
|
||||
|
||||
using Block2CTileMap = typename GridwiseGemm::CBlockClusterAdaptor;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t M01,
|
||||
index_t N01,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t k_batch)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
a_grid_desc_kbatch_k0_m_k1_{},
|
||||
b_grid_desc_kbatch_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
k_batch_{k_batch}
|
||||
{
|
||||
int KPad = DeviceGemmXdlSplitKCShuffle::GetKPad(K, k_batch_);
|
||||
|
||||
a_grid_desc_kbatch_k0_m_k1_ =
|
||||
DeviceGemmXdlSplitKCShuffle::MakeAGridDescriptor_KBatch_K0_M_K1(
|
||||
M, K, StrideA, k_batch_, KPad);
|
||||
b_grid_desc_kbatch_k0_n_k1_ =
|
||||
DeviceGemmXdlSplitKCShuffle::MakeBGridDescriptor_KBatch_K0_N_K1(
|
||||
K, N, StrideB, k_batch_, KPad);
|
||||
c_grid_desc_m_n_ = DeviceGemmXdlSplitKCShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
|
||||
b_grid_desc_kbatch_k0_n_k1_,
|
||||
c_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
index_t k_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceGemmXdlSplitKCShuffle::Argument;
|
||||
|
||||
void Print(const Argument& karg) { karg.Print(); }
|
||||
void Print(const Argument& arg)
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
|
||||
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
Print(karg);
|
||||
Print(arg);
|
||||
}
|
||||
|
||||
const auto kbatch = karg.k_batch;
|
||||
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(karg))
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
|
||||
"setting");
|
||||
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg);
|
||||
const auto K0 = karg.K0;
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(kbatch > 1)
|
||||
hipGetErrorString(
|
||||
hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
|
||||
hipGetErrorString(hipMemset(
|
||||
arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
|
||||
sizeof(CDataType)));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set>;
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4r2<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd>;
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4r2<
|
||||
GridwiseGemmAtomicAdd,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
@@ -180,19 +491,37 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
{
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set>;
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4r2<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd>;
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4r2<
|
||||
GridwiseGemmAtomicAdd,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
@@ -215,9 +544,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& karg)
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(karg);
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -235,9 +567,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t KBatch)
|
||||
{
|
||||
return Argument{p_a,
|
||||
@@ -249,10 +581,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
GridwiseGemm::CalculateMPadded(M),
|
||||
GridwiseGemm::CalculateNPadded(N),
|
||||
GridwiseGemm::CalculateKPadded(K),
|
||||
GridwiseGemm::CalculateK0(K, KBatch),
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
KBatch};
|
||||
}
|
||||
|
||||
@@ -268,9 +601,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
ck::index_t KBatch = 1) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
@@ -282,10 +615,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
GridwiseGemm::CalculateMPadded(M),
|
||||
GridwiseGemm::CalculateNPadded(N),
|
||||
GridwiseGemm::CalculateKPadded(K),
|
||||
GridwiseGemm::CalculateK0(K, KBatch),
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
KBatch);
|
||||
}
|
||||
|
||||
@@ -296,7 +630,31 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); }
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmXdlSplitKCShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< K1 << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< ABlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferDstScalarPerVector_K1
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -316,8 +316,6 @@ struct Sigmoid
|
||||
|
||||
y = 1 / (ck::type_convert<T>(1) + exp(-x));
|
||||
};
|
||||
|
||||
int32_t divider_ = 1;
|
||||
};
|
||||
|
||||
struct TanH
|
||||
@@ -333,6 +331,23 @@ struct TanH
|
||||
};
|
||||
};
|
||||
|
||||
struct Swish
|
||||
{
|
||||
Swish(float beta = 1.0f) : beta_(beta) {}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x));
|
||||
};
|
||||
|
||||
float beta_ = 1.0f;
|
||||
};
|
||||
|
||||
} // namespace element_wise
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -505,6 +505,15 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
|
||||
}
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
|
||||
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
|
||||
// when mfma if fixed, remove this section and update
|
||||
// ABDataTypeAdjusted -> ABDataType throughout this file
|
||||
#if defined(__gfx90a__)
|
||||
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
|
||||
using ABDataTypeAdjusted =
|
||||
conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
|
||||
#else
|
||||
|
||||
@@ -264,6 +264,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
}
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB &&
|
||||
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -266,7 +266,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
|
||||
// when mfma if fixed, remove this section and update
|
||||
// FloatABAdjusted -> FloatAB throughout this file
|
||||
#if defined(__gfx90a__)
|
||||
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
|
||||
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
|
||||
#else
|
||||
using FloatABAdjusted = FloatAB;
|
||||
|
||||
@@ -136,7 +136,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
|
||||
// when mfma if fixed, remove this section and update
|
||||
// FloatABAdjusted -> FloatAB throughout this file
|
||||
#if defined(__gfx90a__)
|
||||
#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
|
||||
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
|
||||
#else
|
||||
using FloatABAdjusted = FloatAB;
|
||||
|
||||
@@ -18,24 +18,61 @@
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AGridDesc_B_K0_M_K1,
|
||||
typename BGridDesc_B_K0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename CBlockClusterAdaptor,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg)
|
||||
kernel_gemm_xdlops_v2r4r2(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
|
||||
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__))
|
||||
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ uint8_t p_shared[shared_size];
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
|
||||
karg, static_cast<void*>(p_shared));
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
static_cast<void*>(p_shared_block),
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
c_block_cluster_adaptor);
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = a_b_k0_m_k1_grid_desc;
|
||||
ignore = b_b_k0_n_k1_grid_desc;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = c_block_cluster_adaptor;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
@@ -43,13 +80,13 @@ template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename AGridDesc_B_K0_M_K1,
|
||||
typename BGridDesc_B_K0_N_K1,
|
||||
typename CMNGridDesc,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t K0PerBlock,
|
||||
@@ -90,238 +127,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
static constexpr auto M01 = 1;
|
||||
static constexpr auto N01 = 1;
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
const FloatAB* p_a_grid;
|
||||
const FloatAB* p_b_grid;
|
||||
FloatC* p_c_grid;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t StrideA;
|
||||
index_t StrideB;
|
||||
index_t StrideC;
|
||||
index_t MPadded;
|
||||
index_t NPadded;
|
||||
index_t KPadded;
|
||||
index_t K0;
|
||||
index_t k_batch;
|
||||
|
||||
Argument(const FloatAB* p_a_grid_,
|
||||
const FloatAB* p_b_grid_,
|
||||
FloatC* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t MPadded_,
|
||||
index_t NPadded_,
|
||||
index_t KPadded_,
|
||||
index_t K0_,
|
||||
index_t k_batch_)
|
||||
: p_a_grid(p_a_grid_),
|
||||
p_b_grid(p_b_grid_),
|
||||
p_c_grid(p_c_grid_),
|
||||
M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
StrideA(StrideA_),
|
||||
StrideB(StrideB_),
|
||||
StrideC(StrideC_),
|
||||
MPadded(MPadded_),
|
||||
NPadded(NPadded_),
|
||||
KPadded(KPadded_),
|
||||
K0(K0_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::cout << "arg {"
|
||||
<< "M:" << M << ", "
|
||||
<< "N:" << N << ", "
|
||||
<< "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", "
|
||||
<< "SB:" << StrideB << ", "
|
||||
<< "SC:" << StrideC << ", "
|
||||
<< "MP:" << MPadded << ", "
|
||||
<< "NP:" << NPadded << ", "
|
||||
<< "KP:" << KPadded << ", "
|
||||
<< "K0:" << K0 << ", "
|
||||
<< "KB:" << k_batch << "}" << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ static auto CalculateGridSize(const Argument& karg)
|
||||
{
|
||||
return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock),
|
||||
math::integer_divide_ceil(karg.M, MPerBlock),
|
||||
karg.k_batch);
|
||||
}
|
||||
|
||||
// prefer this to be called on host
|
||||
__host__ __device__ static auto CalculateMPadded(index_t M)
|
||||
{
|
||||
return (M + MPerBlock - 1) / MPerBlock * MPerBlock;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateNPadded(index_t N)
|
||||
{
|
||||
return (N + NPerBlock - 1) / NPerBlock * NPerBlock;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
// k_batch * k0 * k0_per_block * k1
|
||||
auto K_t = K_Batch * K0PerBlock * K1;
|
||||
return (K + K_t - 1) / K_t * K0PerBlock;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K0 = CalculateK0(K, K_Batch);
|
||||
return K_Batch * K0 * K1;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
|
||||
index_t MPad,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t KBatch,
|
||||
index_t K0,
|
||||
index_t KPad)
|
||||
{
|
||||
const auto a_grid_desc_m_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_kpad,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
|
||||
make_right_pad_transform(M, MPad - M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_kpad,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K,
|
||||
index_t NPad,
|
||||
index_t N,
|
||||
index_t StrideB,
|
||||
index_t KBatch,
|
||||
index_t K0,
|
||||
index_t KPad)
|
||||
{
|
||||
const auto b_grid_desc_k_n = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
|
||||
b_grid_desc_k_n,
|
||||
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_kpad_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
|
||||
make_right_pad_transform(N, NPad - N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_kpad_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeCGridDescriptor_M_N(index_t M, index_t N, index_t MPad, index_t NPad, index_t StrideC)
|
||||
{
|
||||
const auto c_grid_desc_m_n = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
|
||||
{
|
||||
return transform_tensor_descriptor(c_grid_desc_m_n,
|
||||
make_tuple(make_right_pad_transform(M, MPad - M),
|
||||
make_right_pad_transform(N, NPad - N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
@@ -370,68 +179,45 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
c_block_size * sizeof(FloatC));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
template <typename Block2CTileMap>
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
|
||||
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
|
||||
"wrong! K1 need to be known at compile-time");
|
||||
|
||||
static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
|
||||
(NPerBlock % (NRepeat * NPerXDL)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
|
||||
const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
|
||||
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
|
||||
K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
|
||||
K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
|
||||
KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
|
||||
return false;
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
|
||||
return false;
|
||||
|
||||
if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc))
|
||||
{
|
||||
if(!(karg.M % MPerBlock == 0))
|
||||
return false;
|
||||
}
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
|
||||
{
|
||||
if(!(karg.N % NPerBlock == 0))
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
|
||||
return false;
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
|
||||
{
|
||||
const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
|
||||
const index_t KPad = KBatch * K0 * K1;
|
||||
return KPad;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
|
||||
{
|
||||
const bool has_main_k0_block_loop = K0 > K0PerBlock;
|
||||
@@ -439,9 +225,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
return has_main_k0_block_loop;
|
||||
}
|
||||
|
||||
template <typename CGridDesc>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
|
||||
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
@@ -458,11 +243,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
template <typename CGridDesc>
|
||||
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
|
||||
const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
|
||||
const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
|
||||
{
|
||||
return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc>(
|
||||
return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CMNGridDesc>(
|
||||
c_m_n_grid_desc, 8, KBatch);
|
||||
}
|
||||
|
||||
@@ -479,25 +263,24 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation>
|
||||
__device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block)
|
||||
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{}));
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
|
||||
|
||||
template <bool HasMainKBlockLoop>
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
void* __restrict__ p_shared_block,
|
||||
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
|
||||
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op,
|
||||
const CBlockClusterAdaptor& c_block_cluster_adaptor)
|
||||
{
|
||||
const FloatAB* p_a_grid = karg.p_a_grid;
|
||||
const FloatAB* p_b_grid = karg.p_b_grid;
|
||||
FloatC* p_c_grid = karg.p_c_grid;
|
||||
const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
|
||||
karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
|
||||
const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
|
||||
karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded);
|
||||
const auto c_grid_desc_m_n =
|
||||
MakeCGridDescriptor_M_N(karg.M, karg.N, karg.MPadded, karg.NPadded, karg.StrideC);
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
|
||||
const AElementwiseOperation a_element_op = AElementwiseOperation{};
|
||||
const BElementwiseOperation b_element_op = BElementwiseOperation{};
|
||||
const CElementwiseOperation c_element_op = CElementwiseOperation{};
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
@@ -507,16 +290,26 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
|
||||
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const index_t k_batch_id = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
// divide block work by [M, N]
|
||||
const auto block_work_idx =
|
||||
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
if(!c_block_cluster_adaptor.ValidCTileIndex(
|
||||
make_tuple(block_work_idx[I1], block_work_idx[I2]),
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t k_batch_id = block_work_idx[I0];
|
||||
|
||||
// HACK: this force m/n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);
|
||||
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = K1;
|
||||
@@ -652,6 +445,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
@@ -854,7 +648,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{c_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_m_id, 0, block_n_id, 0),
|
||||
make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0),
|
||||
c_element_op};
|
||||
|
||||
constexpr auto mxdlperwave_forward_step =
|
||||
@@ -923,48 +717,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Layout>
|
||||
struct LStr
|
||||
{
|
||||
static std::string Get() { return ""; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct LStr<ck::tensor_layout::gemm::RowMajor>
|
||||
{
|
||||
static std::string Get() { return "R"; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct LStr<ck::tensor_layout::gemm::ColumnMajor>
|
||||
{
|
||||
static std::string Get() { return "C"; }
|
||||
};
|
||||
|
||||
static std::string GetTypeString()
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "GemmXdlSplitKCShuffle_"
|
||||
<< getGemmSpecializationString(GemmSpec) << "_"
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< "_"
|
||||
<< "B" << BlockSize << "_"
|
||||
<< "Vec" << ABlockTransferSrcScalarPerVector << "x"
|
||||
<< BBlockTransferSrcScalarPerVector << "x"
|
||||
<< CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
|
||||
<< MPerBlock << "x"
|
||||
<< NPerBlock << "x"
|
||||
<< K0PerBlock << "x"
|
||||
<< K1 ;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -168,6 +168,10 @@ __device__ double exp<double>(double x)
|
||||
return exp(x);
|
||||
}
|
||||
|
||||
static inline __host__ float exp(float x) { return std::expf(x); }
|
||||
|
||||
static inline __host__ double exp(double x) { return std::exp(x); }
|
||||
|
||||
// greatest common divisor, aka highest common factor
|
||||
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
|
||||
{
|
||||
|
||||
@@ -96,6 +96,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
|
||||
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
|
||||
using Gelu = ck::tensor_operation::element_wise::Gelu;
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
|
||||
template <typename Activation>
|
||||
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// FP16
|
||||
void add_device_normalization_rank_5_3_swish_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Swish, 5, 3>>>&);
|
||||
|
||||
// FP32
|
||||
void add_device_normalization_rank_5_3_swish_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>&);
|
||||
|
||||
// [x, gamma, beta, y] = [f16, f32, f32, f16]
|
||||
void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>>&);
|
||||
|
||||
template <typename XDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename YDataType,
|
||||
index_t Rank,
|
||||
index_t NumReduceDim>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceNormalization<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
F32,
|
||||
YDataType,
|
||||
ck::tensor_operation::element_wise::Swish,
|
||||
Rank,
|
||||
NumReduceDim>>
|
||||
{
|
||||
using DeviceOp = DeviceNormalization<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
F32,
|
||||
YDataType,
|
||||
ck::tensor_operation::element_wise::Swish,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
|
||||
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
|
||||
{
|
||||
if constexpr(Rank == 5 && NumReduceDim == 3)
|
||||
{
|
||||
add_device_normalization_rank_5_3_swish_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
|
||||
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
|
||||
{
|
||||
if constexpr(Rank == 5 && NumReduceDim == 3)
|
||||
{
|
||||
add_device_normalization_rank_5_3_swish_f32_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F32> &&
|
||||
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F16>)
|
||||
{
|
||||
if constexpr(Rank == 5 && NumReduceDim == 3)
|
||||
{
|
||||
add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -26,8 +26,7 @@ using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
|
||||
using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
|
||||
@@ -36,22 +35,14 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
|
||||
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
|
||||
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
add_instance_library(device_normalization_instance
|
||||
device_normalization_f16_instance.cpp
|
||||
device_normalization_f32_instance.cpp
|
||||
device_layernorm2d_f16_instance.cpp
|
||||
device_layernorm2d_f32_instance.cpp
|
||||
device_layernorm4d_f16_instance.cpp
|
||||
device_layernorm4d_f32_instance.cpp
|
||||
device_groupnorm_f16_instance.cpp
|
||||
device_groupnorm_f32_instance.cpp
|
||||
device_groupnorm_swish_f16_instance.cpp
|
||||
device_groupnorm_swish_f32_instance.cpp
|
||||
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
void add_device_normalization_rank_5_3_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Pass, 5, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 5, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
void add_device_normalization_rank_5_3_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 5, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_normalization_f32_instances<Pass, 5, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
|
||||
void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F32, F16, Swish, 5, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_normalization_f16_f32_f32_f16_instances<Swish, 5, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
|
||||
void add_device_normalization_rank_5_3_swish_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Swish, 5, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_normalization_f16_instances<Swish, 5, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
|
||||
void add_device_normalization_rank_5_3_swish_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_normalization_f32_instances<Swish, 5, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
void add_device_normalization_rank_2_1_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Pass, 2, 1>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 2, 1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
void add_device_normalization_rank_2_1_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 2, 1>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_normalization_f32_instances<Pass, 2, 1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
void add_device_normalization_rank_4_3_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Pass, 4, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 4, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "normalization_instance_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
void add_device_normalization_rank_4_3_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 4, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_normalization_f32_instances<Pass, 4, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,70 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
// clang-format off
|
||||
using device_normalization_f16_instances =
|
||||
std::tuple <
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
void add_device_normalization_rank_2_1_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Pass, 2, 1>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 2, 1>{});
|
||||
}
|
||||
|
||||
void add_device_normalization_rank_4_3_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Pass, 4, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 4, 3>{});
|
||||
}
|
||||
|
||||
void add_device_normalization_rank_5_3_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Pass, 5, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 5, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,69 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F32 = float;
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_layernorm_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_normalization_rank_2_1_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 2, 1>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_layernorm_f32_instances<Pass, 2, 1>{});
|
||||
}
|
||||
|
||||
void add_device_normalization_rank_4_3_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 4, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_layernorm_f32_instances<Pass, 4, 3>{});
|
||||
}
|
||||
|
||||
void add_device_normalization_rank_5_3_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 5, 3>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_layernorm_f32_instances<Pass, 5, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,101 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f16_instances =
|
||||
// clang-format off
|
||||
std::tuple <
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 64, 1, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
|
||||
DeviceNormalizationImpl<F16, F16, F16, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 16, 1, 8, 1, 8, 1, 8, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F32, F32, F32, F32, F32, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename OutElementwise, index_t Rank, index_t Reduce>
|
||||
using device_normalization_f16_f32_f32_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
// XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 1, 1, 1, 1, 1, 1, 1, 1>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 2, 1, 2, 1, 2, 1, 2, 2>, // irregular size
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 128, 1, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 2, 16, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 512, 1, 512, 2, 8, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 4, 1, 4, 1, 4, 1, 4, 4>,
|
||||
DeviceNormalizationImpl<F16, F32, F32, F32, F16, OutElementwise, Rank, Reduce, 1024, 1, 1024, 1, 8, 1, 4, 1, 4, 1, 4, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -190,9 +190,9 @@ bool profile_groupnorm_impl(int do_verification,
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
LogRange(std::cout << "length = ", length, ",") << ", ";
|
||||
std::cout << "num_kernel = " << num_kernel << ", best perf = " << best_avg_time << " ms, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_instance_name << std::endl;
|
||||
LogRange(std::cout << "length = ", length, ",") << std::endl;
|
||||
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, "
|
||||
<< best_instance_name << std::endl;
|
||||
}
|
||||
|
||||
if(num_kernel == 0)
|
||||
|
||||
Reference in New Issue
Block a user