Merge branch 'develop' into amd-develop

This commit is contained in:
Jun Liu
2024-01-31 11:48:58 -08:00
78 changed files with 5662 additions and 967 deletions

2
.github/CODEOWNERS vendored
View File

@@ -1,4 +1,4 @@
* @zjing14 @asroy @junliume @illsilin @carlushuang
* @zjing14 @asroy @junliume @illsilin @carlushuang @aosewski
# Documentation files
docs/* @saadrahim @LisaDelaney
*.md @saadrahim @LisaDelaney

View File

@@ -2,53 +2,66 @@
Full documentation for Composable Kernel is not yet available.
## (Unreleased) CK for ROCm 6.0.0
## (Unreleased) CK
### Fixes
- Fixed a hazard associated with inline v_dot (#808)
- Fixed two bugs in grouped convolution backward data without K padding (#848 #876)
None
### Optimizations
None
### Additions
- Added an image to a column kernel (#867)
- Added a column to an image kernel (#930)
- Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985)
- Grouped convolution support for small K and C (#822 #879 #897)
- Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
- Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799)
- Support for Batched Gemm DL (#732)
- Introduce wrapper sublibrary (limited functionality). (#1071, #1098, #1108)
* Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126)
### Changes
- Changed the grouped convolution API to maintain consistency with other convolution kernels (#817)
None
## CK for ROCm 6.0.0
### Fixes
* Fixed a hazard associated with inline v_dot (#808)
* Fixed two bugs in grouped convolution backward data without K padding (#848 #876)
### Optimizations
None
### Additions
* Added an image to a column kernel (#867)
* Added a column to an image kernel (#930)
* Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985)
* Grouped convolution support for small K and C (#822 #879 #897)
* Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
* Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799)
* Support for Batched Gemm DL (#732)
### Changes
* Changed the grouped convolution API to maintain consistency with other convolution kernels (#817)
## CK 0.2.0 for ROCm 5.7.0
### Fixes
- Fixed a bug in 6-dimensional kernels (#555)
- Fixed a test case failure with grouped convolution backward weight (#524)
* Fixed a bug in 6-dimensional kernels (#555)
* Fixed a test case failure with grouped convolution backward weight (#524)
### Optimizations
- Improved the performance of the normalization kernel
* Improved the performance of the normalization kernel
### Additions
- New CMake flags:
- "DL_KERNELS"-- Must be set to "ON" in order to build the gemm_dl and batched_gemm_multi_d_dl instances
- "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types
- "INSTANCES_ONLY" -- Only builds CK library and instances without tests, examples, or profiler
- New feature: if GPU_TARGETS is not set in the CMake command line, CK will be built for all targets supported by the compiler
- Support for MI300A/MI300X
- Support for AMD RDNA 3
- New user tutorial (#563)
- Additional instances for irregular GEMM sizes (#560)
- New inter-wave consumer-producer programming model for GEMM kernels (#310)
- GEMM with support multiple elementwise fusions (multi-D) (#534)
- Multi-embeddings support (#542)
- AMD RDNA 3 blockwise GEMM and real GEMM support (#541)
- AMD RDNA grouped convolution backward weight support (#505)
- MaxPool and AvgPool forward (#815); MaxPool backward (#750)
* New CMake flags:
* "DL_KERNELS" -- Must be set to "ON" in order to build the gemm_dl and batched_gemm_multi_d_dl instances
* "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types
* "INSTANCES_ONLY" -- Only builds CK library and instances without tests, examples, or profiler
* New feature: if GPU_TARGETS is not set in the CMake command line, CK will be built for all targets supported by the compiler
* Support for MI300A/MI300X
* Support for AMD RDNA 3
* New user tutorial (#563)
* Additional instances for irregular GEMM sizes (#560)
* New inter-wave consumer-producer programming model for GEMM kernels (#310)
* GEMM with support for multiple elementwise fusions (multi-D) (#534)
* Multi-embeddings support (#542)
* AMD RDNA 3 blockwise GEMM and real GEMM support (#541)
* AMD RDNA grouped convolution backward weight support (#505)
* MaxPool and AvgPool forward (#815); MaxPool backward (#750)
### Changes
None

View File

@@ -122,7 +122,7 @@ ENV compiler_commit=$compiler_commit
RUN sh -c "echo compiler version = '$compiler_version'"
RUN sh -c "echo compiler commit = '$compiler_commit'"
RUN if ( [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \
RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \
git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \
cd llvm-project && mkdir build && cd build && \
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
@@ -130,7 +130,7 @@ RUN if ( [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "am
else echo "using the release compiler"; \
fi
RUN if ( [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" != "" ]; then \
RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" != "" ]; then \
git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \
cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \

34
Jenkinsfile vendored
View File

@@ -84,7 +84,7 @@ def build_compiler(){
compiler = '/opt/rocm/bin/hipcc'
}
else{
if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
compiler = "/llvm-project/build/bin/clang++"
}
else{
@@ -293,7 +293,7 @@ def buildHipClangJob(Map conf=[:]){
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
}
@@ -304,7 +304,7 @@ def buildHipClangJob(Map conf=[:]){
gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
timeout(time: 20, unit: 'HOURS')
timeout(time: 48, unit: 'HOURS')
{
cmake_build(conf)
}
@@ -348,7 +348,7 @@ def runCKProfiler(Map conf=[:]){
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
}
@@ -479,7 +479,7 @@ def Build_CK(Map conf=[:]){
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
}
@@ -560,7 +560,7 @@ def Build_CK(Map conf=[:]){
sh """#!/bin/bash
mkdir -p build
ls -ltr
CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="/opt/rocm;${env.WORKSPACE}/install"
CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install"
cmake --build build -- -j
"""
}
@@ -657,7 +657,7 @@ def process_results(Map conf=[:]){
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.0;COMPILER_VERSION=
0 21 * * * % ROCMVERSION=6.0;COMPILER_VERSION=;COMPILER_COMMIT=
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=;USE_SCCACHE=false
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;COMPILER_COMMIT=;USE_SCCACHE=false
0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false''' : ""
pipeline {
@@ -680,7 +680,7 @@ pipeline {
string(
name: 'COMPILER_VERSION',
defaultValue: '',
description: 'Specify which version of compiler to use: release, amd-stg-open, amd-mainline-open, or leave blank (default).')
description: 'Specify which version of compiler to use: release, amd-staging, amd-mainline-open, or leave blank (default).')
string(
name: 'COMPILER_COMMIT',
defaultValue: '',
@@ -713,6 +713,10 @@ pipeline {
name: "RUN_CPPCHECK",
defaultValue: false,
description: "Run the cppcheck static analysis (default: OFF)")
booleanParam(
name: "RUN_PERFORMANCE_TESTS",
defaultValue: false,
description: "Run the performance tests (default: OFF)")
}
environment{
dbuser = "${dbuser}"
@@ -755,7 +759,11 @@ pipeline {
-o -iname \'*.cl\' \
| grep -v 'build/' \
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \
/cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include --file-filter=*.cpp --enable=all --output-file=ck_cppcheck.log"
/cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \
-D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 -D DL_KERNELS \
-D __gfx908__ -D __gfx90a__ -D __gfx940__ -D __gfx941__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \
-U __gfx803__ -U __gfx900__ -U __gfx906__ -U CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 \
--file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log"
}
steps{
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
@@ -886,7 +894,7 @@ pipeline {
{
when {
beforeAgent true
expression { !params.RUN_FULL_QA.toBoolean() }
expression { !params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() }
}
options { retry(2) }
agent{ label rocmnode("gfx908 || gfx90a")}
@@ -902,7 +910,7 @@ pipeline {
{
when {
beforeAgent true
expression { params.RUN_FULL_QA.toBoolean() }
expression { params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() }
}
options { retry(2) }
agent{ label rocmnode("gfx90a")}
@@ -921,6 +929,10 @@ pipeline {
parallel
{
stage("Process results"){
when {
beforeAgent true
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() }
}
agent { label 'mici' }
steps{
process_results()

View File

@@ -1,6 +1,9 @@
add_executable(client_layernorm2d_bwd_data layernorm2d_bwd_data.cpp)
target_link_libraries(client_layernorm2d_bwd_data PRIVATE composable_kernel::device_other_operations)
add_executable(client_layernorm2d_bwd_gamma_beta layernorm2d_bwd_gamma_beta.cpp)
target_link_libraries(client_layernorm2d_bwd_gamma_beta PRIVATE composable_kernel::device_other_operations)
add_executable(client_layernorm2d_fwd layernorm2d_fwd.cpp)
target_link_libraries(client_layernorm2d_fwd PRIVATE composable_kernel::device_other_operations)

View File

@@ -0,0 +1,171 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp"
#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp"
using DYDataType = float;
using XDataType = float;
using GammaDataType = float;
using MeanInvStdDataType = float;
using DGammaDataType = float;
using DBetaDataType = float;
constexpr int Rank = 2;
constexpr int NumReduceDim = 1;
/// Minimal owning wrapper around one device allocation obtained with hipMalloc.
///
/// The allocation is released with hipFree when the object is destroyed. The
/// status codes returned by hipMalloc and hipFree are intentionally discarded.
struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    /// Requests mem_size bytes of device memory.
    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    // This object is the sole owner of p_mem_. An implicit copy would leave two
    // objects whose destructors both call hipFree on the same pointer, so
    // copying is disabled.
    SimpleDeviceMem(const SimpleDeviceMem&) = delete;
    SimpleDeviceMem& operator=(const SimpleDeviceMem&) = delete;

    /// Returns the raw pointer stored by the constructor.
    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};
/// Client example for the 2D layernorm backward gamma/beta operation.
///
/// Allocates device buffers for a fixed M x N problem, times every instance
/// returned by the instance factory, reports the fastest supported one and then
/// runs that instance once more without timing.
///
/// @return always 0.
int main(int argc, char* argv[])
{
    // This example takes no command-line options.
    (void)argc;
    (void)argv;

    ck::index_t M = 1024;
    ck::index_t N = 1024;

    SimpleDeviceMem dy_dev(sizeof(DYDataType) * M * N);
    SimpleDeviceMem x_dev(sizeof(XDataType) * M * N);
    SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * M);
    SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * M);
    SimpleDeviceMem dgamma_dev(sizeof(DGammaDataType) * N);
    SimpleDeviceMem dbeta_dev(sizeof(DBetaDataType) * N);

    using DeviceOp =
        ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta<DYDataType,
                                                                      XDataType,
                                                                      MeanInvStdDataType,
                                                                      DGammaDataType,
                                                                      DBetaDataType,
                                                                      Rank,
                                                                      NumReduceDim>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    // The timed pass and the final untimed pass describe the same problem, so
    // the argument is built in one place.
    const auto make_argument = [&](const auto& op) {
        return op->MakeArgumentPointer({M, N}, // inLengths
                                       {N, 1}, // dyStrides
                                       {N, 1}, // xStrides
                                       {1, 0}, // meanStrides
                                       {1, 0}, // invStdStrides
                                       {N},    // outLengths
                                       {1},    // dgammaStrides
                                       {1},    // dbetaStrides
                                       {0},    // reduceDims
                                       dy_dev.GetDeviceBuffer(),
                                       x_dev.GetDeviceBuffer(),
                                       mean_dev.GetDeviceBuffer(),
                                       inv_std_dev.GetDeviceBuffer(),
                                       dgamma_dev.GetDeviceBuffer(),
                                       dbeta_dev.GetDeviceBuffer());
    };

    std::string best_op_name;
    bool found            = false;
    int best_op_id        = -1;
    float best_ave_time   = std::numeric_limits<float>::max();
    float best_gb_per_sec = 0;

    // profile device operation instances
    std::cout << "Run all instances and do timing" << std::endl;

    std::size_t num_bytes = sizeof(DYDataType) * M * N + sizeof(XDataType) * M * N +
                            sizeof(MeanInvStdDataType) * M * 2 + sizeof(DGammaDataType) * N +
                            sizeof(DBetaDataType) * N;

    // size() is unsigned; index with std::size_t to avoid a signed/unsigned comparison.
    for(std::size_t i = 0; i < op_ptrs.size(); ++i)
    {
        auto& op_ptr = op_ptrs[i];

        auto argument_ptr   = make_argument(op_ptr);
        auto invoker_ptr    = op_ptr->MakeInvokerPointer();
        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
            SimpleDeviceMem workspace(workspace_sz);
            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());

            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            // bytes / 1e6 / milliseconds == GB/s
            float gb_per_sec = num_bytes / 1.E6 / ave_time;

            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
                      << op_name << std::endl;

            if(ave_time < best_ave_time)
            {
                found           = true;
                best_op_id      = static_cast<int>(i);
                best_op_name    = op_name;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }
        }
        else
        {
            std::cout << op_name << " does not support this problem" << std::endl;
        }
    }

    // run the best instance
    if(found)
    {
        // Only report a best result when one exists; otherwise best_ave_time
        // still holds its numeric_limits<float>::max() sentinel.
        std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
                  << best_op_name << std::endl;

        auto& op_ptr = op_ptrs[best_op_id];
        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;

        auto argument_ptr = make_argument(op_ptr);
        auto invoker_ptr  = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
            SimpleDeviceMem workspace(workspace_sz);
            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());

            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }

        std::cout << "Done" << std::endl;
    }

    return 0;
}

View File

@@ -1,5 +1,8 @@
add_executable(client_groupnorm_bwd_data groupnorm_bwd_data.cpp)
target_link_libraries(client_groupnorm_bwd_data PRIVATE composable_kernel::device_other_operations)
add_executable(client_groupnorm_bwd_gamma_beta groupnorm_bwd_gamma_beta.cpp)
target_link_libraries(client_groupnorm_bwd_gamma_beta PRIVATE composable_kernel::device_other_operations)
add_executable(client_groupnorm_swish_fwd groupnorm_swish_fwd.cpp)
target_link_libraries(client_groupnorm_swish_fwd PRIVATE composable_kernel::device_other_operations)

View File

@@ -0,0 +1,180 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp"
#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp"
using DYDataType = float;
using XDataType = float;
using GammaDataType = float;
using MeanInvStdDataType = float;
using DGammaDataType = float;
using DBetaDataType = float;
constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
/// Minimal owning wrapper around one device allocation obtained with hipMalloc.
///
/// The allocation is released with hipFree when the object is destroyed. The
/// status codes returned by hipMalloc and hipFree are intentionally discarded.
struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    /// Requests mem_size bytes of device memory.
    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    // This object is the sole owner of p_mem_. An implicit copy would leave two
    // objects whose destructors both call hipFree on the same pointer, so
    // copying is disabled.
    SimpleDeviceMem(const SimpleDeviceMem&) = delete;
    SimpleDeviceMem& operator=(const SimpleDeviceMem&) = delete;

    /// Returns the raw pointer stored by the constructor.
    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};
/// Client example for the groupnorm backward gamma/beta operation.
///
/// Allocates device buffers for a fixed [N, H, W, G, C] problem reduced over
/// dimensions {0, 1, 2}, times every instance returned by the instance factory,
/// reports the fastest supported one and then runs it once more without timing.
///
/// @return always 0.
int main(int argc, char* argv[])
{
    // This example takes no command-line options.
    (void)argc;
    (void)argv;

    ck::index_t N = 32;
    ck::index_t H = 16;
    ck::index_t W = 16;
    ck::index_t G = 64;
    ck::index_t C = 128;

    // Widen before multiplying so the element count is not computed in
    // 32-bit ck::index_t arithmetic.
    const std::size_t length = static_cast<std::size_t>(N) * H * W * G * C;

    std::vector<ck::index_t> strideDy         = {H * W * G * C, W * G * C, G * C, C, 1};
    std::vector<ck::index_t> strideX          = strideDy;
    std::vector<ck::index_t> strideMeanInvStd = {G, 0, 0, 1, 0};
    std::vector<ck::index_t> strideDGammaBeta = {C, 1};

    SimpleDeviceMem dy_dev(sizeof(DYDataType) * length);
    SimpleDeviceMem x_dev(sizeof(XDataType) * length);
    SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * N * G);
    SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * N * G);
    SimpleDeviceMem dgamma_dev(sizeof(DGammaDataType) * G * C);
    SimpleDeviceMem dbeta_dev(sizeof(DBetaDataType) * G * C);

    using DeviceOp =
        ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta<DYDataType,
                                                                      XDataType,
                                                                      MeanInvStdDataType,
                                                                      DGammaDataType,
                                                                      DBetaDataType,
                                                                      Rank,
                                                                      NumReduceDim>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    // The timed pass and the final untimed pass describe the same problem, so
    // the argument is built in one place.
    const auto make_argument = [&](const auto& op) {
        return op->MakeArgumentPointer({N, H, W, G, C},
                                       strideDy,
                                       strideX,
                                       strideMeanInvStd,
                                       strideMeanInvStd,
                                       {G, C},
                                       strideDGammaBeta,
                                       strideDGammaBeta,
                                       {0, 1, 2}, // reduceDims
                                       dy_dev.GetDeviceBuffer(),
                                       x_dev.GetDeviceBuffer(),
                                       mean_dev.GetDeviceBuffer(),
                                       inv_std_dev.GetDeviceBuffer(),
                                       dgamma_dev.GetDeviceBuffer(),
                                       dbeta_dev.GetDeviceBuffer());
    };

    std::string best_op_name;
    bool found            = false;
    int best_op_id        = -1;
    float best_ave_time   = std::numeric_limits<float>::max();
    float best_gb_per_sec = 0;

    // profile device operation instances
    std::cout << "Run all instances and do timing" << std::endl;

    // Only the buffers handed to the operation are counted: dy, x, mean and
    // inv_std are passed as inputs, dgamma and dbeta as outputs. No gamma
    // buffer is passed to MakeArgumentPointer, so it is not part of the traffic.
    std::size_t num_bytes = sizeof(DYDataType) * length + sizeof(XDataType) * length +
                            sizeof(MeanInvStdDataType) * N * G * 2 +
                            sizeof(DGammaDataType) * G * C + sizeof(DBetaDataType) * G * C;

    // size() is unsigned; index with std::size_t to avoid a signed/unsigned comparison.
    for(std::size_t i = 0; i < op_ptrs.size(); ++i)
    {
        auto& op_ptr = op_ptrs[i];

        auto argument_ptr   = make_argument(op_ptr);
        auto invoker_ptr    = op_ptr->MakeInvokerPointer();
        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
            SimpleDeviceMem workspace(workspace_sz);
            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());

            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            // bytes / 1e6 / milliseconds == GB/s
            float gb_per_sec = num_bytes / 1.E6 / ave_time;

            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
                      << op_name << std::endl;

            if(ave_time < best_ave_time)
            {
                found           = true;
                best_op_id      = static_cast<int>(i);
                best_op_name    = op_name;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }
        }
        else
        {
            std::cout << op_name << " does not support this problem" << std::endl;
        }
    }

    // run the best instance
    if(found)
    {
        std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
                  << best_op_name << std::endl;

        auto& op_ptr = op_ptrs[best_op_id];
        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;

        auto argument_ptr = make_argument(op_ptr);
        auto invoker_ptr  = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
            SimpleDeviceMem workspace(workspace_sz);
            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());

            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }

        std::cout << "Done" << std::endl;
    }

    return 0;
}

View File

@@ -1,150 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
using DataType = int;
template <typename Desc>
void Print1d(const Desc& desc)
{
std::cout << "Print1d" << std::endl;
for(ck::index_t w = 0; w < desc.GetLength(I0); w++)
{
std::cout << desc.CalculateOffset(ck::make_tuple(w)) << " ";
}
std::cout << std::endl;
}
template <typename Desc>
void Print2d(const Desc& desc)
{
std::cout << "Print2d" << std::endl;
for(ck::index_t h = 0; h < desc.GetLength(I0); h++)
{
for(ck::index_t w = 0; w < desc.GetLength(I1); w++)
{
std::cout << desc.CalculateOffset(ck::make_tuple(h, w)) << " ";
}
std::cout << std::endl;
}
}
template <typename Desc>
void Print3dCustom(const Desc& desc)
{
std::cout << "Print3dCustom" << std::endl;
for(ck::index_t d = 0; d < desc.GetLength(I0); d++)
{
for(ck::index_t h = 0; h < desc.GetLength(I1); h++)
{
for(ck::index_t w = 0; w < desc.GetLength(I2); w++)
{
std::cout << desc.CalculateOffset(ck::make_tuple(d, h, w)) << " ";
}
std::cout << std::endl;
}
std::cout << std::endl;
}
}
// Example driver: builds several naive tensor descriptors, reshapes them with
// pass-through and merge transforms, and prints the offsets each resulting
// descriptor produces via Print1d / Print2d / Print3dCustom.
int main()
{
// Tensor descriptors are traversed in row-major order (dims need to be reversed)
std::cout << "Note: Tensor descriptor traverse in row-major" << std::endl;
// Basic descriptor 0, 1, 2, ... 30, 31
// (dims:4,8 strides:1,4)
// Lengths and strides are given as ck::Number<> constants here.
const auto desc_4x8_s1x4 =
ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}));
std::cout << "dims:4,8 strides:1,4" << std::endl;
Print2d(desc_4x8_s1x4);
// The offset of coordinate [1, 1] is evaluated in a constexpr context.
using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
constexpr ck::index_t offset_1x1 = desc_4x8_s1x4.CalculateOffset(Cord1x1Type{});
std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl;
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:4,(2,4) strides:2,(1,8)
// NOTE(review): lengths/strides below are plain ints, not ck::Number<>;
// confirm whether the "compile-time descriptor" label above is intended.
const auto desc_4x2x4_s2x1x8 =
ck::make_naive_tensor_descriptor(ck::make_tuple(4, 2, 4), ck::make_tuple(2, 1, 8));
// Transform to 2d (column-major, need to reverse dims)
// Dimension 0 is passed through; dimensions 2 and 1 are merged into new dimension 1.
const auto desc_4x2x4_s2x1x8_merged = ck::transform_tensor_descriptor(
desc_4x2x4_s2x1x8,
ck::make_tuple(ck::make_pass_through_transform(4),
ck::make_merge_transform(ck::make_tuple(4, 2))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<2, 1>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
Print2d(desc_4x2x4_s2x1x8_merged);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:(2,2),(2,4) strides:(1,4),(2,8)
const auto desc_2x2x2x4_s1x4x2x8 =
ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8));
// Transform to 2d
// Dimensions (1, 0) merge into new dimension 0; dimensions (3, 2) merge into new dimension 1.
const auto desc_2x2x2x4_s1x4x2x8_double_merged_2d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
ck::make_merge_transform(ck::make_tuple(4, 2))),
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
// Transform to 3d
// Dimensions 0 and 1 are passed through; dimensions (3, 2) merge into new dimension 2.
const auto desc_2x2x2x4_s1x4x2x8_double_merged_3d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8,
ck::make_tuple(ck::make_pass_through_transform(2),
ck::make_pass_through_transform(2),
ck::make_merge_transform(ck::make_tuple(4, 2))),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<3, 2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
Print2d(desc_2x2x2x4_s1x4x2x8_double_merged_2d);
Print3dCustom(desc_2x2x2x4_s1x4x2x8_double_merged_3d);
// Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
// dims:((2,2),2),4 strides:((1,4),2),8
// The same naive descriptor is transformed to 3d, 1d and 2d views below.
const auto desc_2x2x2x4_s1x4x2x8_nested =
ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8));
// 3d view: dimensions (1, 0) merge into new dimension 0; dimensions 2 and 3 are passed through.
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_3d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8_nested,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
ck::make_pass_through_transform(2),
ck::make_pass_through_transform(4)),
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
// 1d view: all four dimensions, taken in order (3, 2, 1, 0), merge into one dimension.
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_1d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8_nested,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(4, 2, 2, 2))),
ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}),
ck::make_tuple(ck::Sequence<0>{}));
// 2d view: built on top of the 3d view by merging its dimensions (1, 0) and
// passing its dimension 2 through.
const auto desc_2x2x2x4_s1x4x2x8_nested_merged_2d = ck::transform_tensor_descriptor(
desc_2x2x2x4_s1x4x2x8_nested_merged_3d,
ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 4)),
ck::make_pass_through_transform(4)),
ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}),
ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl;
Print1d(desc_2x2x2x4_s1x4x2x8_nested_merged_1d);
Print2d(desc_2x2x2x4_s1x4x2x8_nested_merged_2d);
Print3dCustom(desc_2x2x2x4_s1x4x2x8_nested_merged_3d);
return 0;
}

View File

@@ -1,4 +1,4 @@
add_executable(client_tensor_transform tensor_transform.cpp)
target_link_libraries(client_tensor_transform PRIVATE composable_kernel::device_other_operations)
add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
add_executable(client_wrapper_img2col wrapper_img2col.cpp)
target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>

View File

@@ -0,0 +1,180 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <numeric>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <initializer_list>
#include <vector>
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
static constexpr ck::index_t NumDimSpatial = 3;
using DataType = float;
using InputLayout = ck::tensor_layout::convolution::NDHWGC;
/// Minimal owning wrapper around one device allocation obtained with hipMalloc.
///
/// The allocation is released with hipFree when the object is destroyed. The
/// status codes returned by hipMalloc and hipFree are intentionally discarded.
struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    /// Requests mem_size bytes of device memory.
    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    // This object is the sole owner of p_mem_. An implicit copy would leave two
    // objects whose destructors both call hipFree on the same pointer, so
    // copying is disabled.
    SimpleDeviceMem(const SimpleDeviceMem&) = delete;
    SimpleDeviceMem& operator=(const SimpleDeviceMem&) = delete;

    /// Returns the raw pointer stored by the constructor.
    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};
// Kernel that copies one tile of input_tensor to the matching tile of
// output_tensor. The tile is selected by blockIdx.x and divided between the
// threads of the block according to thread_layout; each thread then copies its
// own partition with ck::wrapper::copy.
// NOTE(review): the previous comment described a copy "through LDS and VGPR";
// no explicit LDS staging is visible in this body - confirm against
// ck::wrapper::copy.
template <typename InputTensor,
          typename OutputTensor,
          typename BlockShape,
          typename ThreadLayoutShape>
__global__ void DeviceImageToColumnPad0(InputTensor input_tensor,
                                        OutputTensor output_tensor,
                                        const BlockShape tile_shape,
                                        const ThreadLayoutShape thread_layout)
{
    // Copy configuration: dimensions visited in order (0, 1), vectorized along
    // dimension 1 with 4 scalars per vector.
    using AccessOrder                    = ck::Tuple<ck::Number<0>, ck::Number<1>>;
    constexpr ck::index_t kVectorDim     = 1;
    constexpr ck::index_t kScalarsPerVec = 4;

    const ck::index_t tile_id = static_cast<ck::index_t>(blockIdx.x);

    // Source side: tile owned by this block, then the slice owned by this thread.
    auto src_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, tile_id);
    const auto src_partition =
        ck::wrapper::make_local_partition(src_tile, thread_layout, threadIdx.x);

    // Destination side: same tile and thread selection on the output tensor.
    auto dst_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, tile_id);
    auto dst_partition =
        ck::wrapper::make_local_partition(dst_tile, thread_layout, threadIdx.x);

    ck::wrapper::copy<AccessOrder, kVectorDim, kScalarsPerVec>(src_partition, dst_partition);
}
/// Runs and times a pad-free 3D image-to-column transform built with the CK wrapper API,
/// then prints the average kernel time and the achieved bandwidth.
///
/// Both tensors are viewed through the same two-dimensional shape
/// ((G, (Wo, Ho, Do, N)), (C, X, Y, Z)); only their strides differ:
///  - input strides address a buffer ordered as N, Di, Hi, Wi, G, C, with the filter
///    strides/dilations folded into the spatial strides;
///  - output strides address a packed buffer ordered as N, Do, Ho, Wo, G, Z, Y, X, C.
///
/// The device buffers are only allocated here; their contents are not initialized or
/// verified by this function.
///
/// @param G,N               Number of groups and batch size.
/// @param Di,Hi,Wi          Input spatial sizes (depth, height, width).
/// @param Do,Ho,Wo          Output spatial sizes (depth, height, width).
/// @param C                 Input channels per group.
/// @param Z,Y,X             Filter sizes (depth, height, width).
/// @param filter_strides    Filter strides, indexed [0]=D, [1]=H, [2]=W.
/// @param filter_dilations  Filter dilations, indexed [0]=D, [1]=H, [2]=W.
void PerformImageToColumnPad0(const ck::index_t G,
                              const ck::index_t N,
                              const ck::index_t Di,
                              const ck::index_t Hi,
                              const ck::index_t Wi,
                              const ck::index_t Do,
                              const ck::index_t Ho,
                              const ck::index_t Wo,
                              const ck::index_t C,
                              const ck::index_t Z,
                              const ck::index_t Y,
                              const ck::index_t X,
                              std::array<ck::index_t, NumDimSpatial> filter_strides,
                              std::array<ck::index_t, NumDimSpatial> filter_dilations)
{
    const ck::index_t ZYXC = Z * Y * X * C;
    const ck::index_t GC   = G * C;

    // shape: ((G, (Wo, Ho, Do, N)), (C, X, Y, Z))
    const auto shape = ck::make_tuple(ck::make_tuple(G, ck::make_tuple(Wo, Ho, Do, N)),
                                      ck::make_tuple(C, X, Y, Z));

    const auto in_strides =
        ck::make_tuple(ck::make_tuple(C,
                                      ck::make_tuple(filter_strides[2] * GC,
                                                     filter_strides[1] * Wi * GC,
                                                     filter_strides[0] * Hi * Wi * GC,
                                                     Di * Hi * Wi * GC)),
                       ck::make_tuple(1,
                                      filter_dilations[2] * GC,
                                      filter_dilations[1] * Wi * GC,
                                      filter_dilations[0] * Hi * Wi * GC));
    const auto in_layout = ck::wrapper::make_layout(shape, in_strides);

    const auto out_strides = ck::make_tuple(
        ck::make_tuple(
            ZYXC,
            ck::make_tuple(ZYXC * G, Wo * ZYXC * G, Ho * Wo * ZYXC * G, Do * Ho * Wo * ZYXC * G)),
        ck::make_tuple(1, C, X * C, Y * X * C));
    const auto out_layout = ck::wrapper::make_layout(shape, out_strides);

    // Accumulate the element count in std::size_t: multiplying in ck::index_t first
    // (assumed to be a 32-bit signed type) could overflow for large problems before the
    // value is widened to a byte count.
    const std::size_t input_size = static_cast<std::size_t>(N) * Di * Hi * Wi * GC;

    // Global memory buffers
    SimpleDeviceMem in_buf(input_size * sizeof(DataType));
    SimpleDeviceMem out_buf(ck::wrapper::size(out_layout) * sizeof(DataType));

    // User can choose appropriate number of threads and sizes per block
    const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{});
    // This example doesn't support padding, user should select tile sizes
    // which divides the shape completely
    const auto tile_shape = ck::make_tuple(ck::Number<32>{}, ck::Number<64>{});

    // Create buffers for global memory
    auto input_tensor_global = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<const DataType*>(in_buf.GetDeviceBuffer()), in_layout);
    auto output_tensor_global = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<DataType*>(out_buf.GetDeviceBuffer()), out_layout);

    // One block per tile: number of tiles along dim 0 times number of tiles along dim 1.
    const ck::index_t grid_size = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout),
                                                                ck::wrapper::size<0>(tile_shape)) *
                                  ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout),
                                                                ck::wrapper::size<1>(tile_shape));

    const auto kernel = DeviceImageToColumnPad0<decltype(input_tensor_global),
                                                decltype(output_tensor_global),
                                                decltype(tile_shape),
                                                decltype(thread_layout)>;
    const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
                                                  kernel,
                                                  dim3(grid_size),
                                                  dim3(ck::wrapper::size(thread_layout)),
                                                  0,
                                                  input_tensor_global,
                                                  output_tensor_global,
                                                  tile_shape,
                                                  thread_layout);

    // Bytes moved: every output element is read once and written once. sizeof() comes
    // first so the whole product is evaluated in std::size_t rather than ck::index_t.
    const std::size_t num_btype = sizeof(DataType) * 2 * G * N * Do * Ho * Wo * ZYXC;
    // avg_time is reported in milliseconds, hence bytes / 1e6 / ms == GB/s.
    float gb_per_sec = num_btype / 1.E6 / avg_time;
    std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
              << std::endl;
}
/// Entry point: runs the image-to-column example on one fixed grouped 3D problem
/// with unit filter strides and unit filter dilations.
int main(int argc, char* argv[])
{
    // Groups, batch and per-group input channels.
    constexpr ck::index_t G = 4;
    constexpr ck::index_t N = 32;
    constexpr ck::index_t C = 64;

    // Filter extent: depth (Z), height (Y), width (X).
    constexpr ck::index_t Z = 3;
    constexpr ck::index_t Y = 3;
    constexpr ck::index_t X = 3;

    // Input extent: depth, height, width.
    constexpr ck::index_t Di = 9;
    constexpr ck::index_t Hi = 9;
    constexpr ck::index_t Wi = 7;

    // Output extent: depth, height, width.
    constexpr ck::index_t Do = 7;
    constexpr ck::index_t Ho = 7;
    constexpr ck::index_t Wo = 5;

    PerformImageToColumnPad0(
        G, N, Di, Hi, Wi, Do, Ho, Wo, C, Z, Y, X, {1, 1, 1} /*filter_strides*/, {1, 1, 1}
        /*filter_dilations*/);

    return 0;
}

View File

@@ -1,2 +1,2 @@
rocm-docs-core==0.31.0
rocm-docs-core==0.33.0
sphinxcontrib-bibtex==2.6.2

View File

@@ -113,7 +113,7 @@ requests==2.31.0
# via
# pygithub
# sphinx
rocm-docs-core==0.31.0
rocm-docs-core==0.33.0
# via -r requirements.in
six==1.16.0
# via

View File

@@ -18,8 +18,7 @@ Description
The CK library provides a lightweight wrapper for more complex operations implemented in
the library. It allows indexing of nested layouts using a simple interface
(avoiding complex descriptor transformations) and memory access (using Tensor).
the library.
Example:
@@ -54,6 +53,11 @@ Output::
1 5 9 13 17 21 25 29
2 6 10 14 18 22 26 30
Advanced examples:
* `Image to column <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_img2col.cpp>`_
-------------------------------------
Layout
-------------------------------------

View File

@@ -19,6 +19,9 @@ add_custom_target(example_gemm_xdl)
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2)
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)

View File

@@ -1,9 +1,7 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#include "common.hpp"
@@ -43,3 +41,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
#endif

View File

@@ -0,0 +1,51 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
// Element types used by this example: A, B, C and the C-shuffle stage are fp16,
// accumulation is done in float.
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t;
// Short aliases used in the device instance argument list below.
using F16 = ck::half_t;
using F32 = float;
// Matrix layouts of A, B and C (all `Row`; `Row` is declared outside this file,
// presumably in common.hpp).
using ALayout = Row;
using BLayout = Row;
using CLayout = Row;
// Element-wise operations applied to A, B and C (all PassThrough).
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
// GEMM specialization passed to the device instance.
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Device GEMM instance exercised by this example. The numeric arguments are the
// tuning parameters of DeviceGemm_Xdl_CShuffleV2; see that template's declaration
// for the meaning of each position.
// clang-format off
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV2<
ALayout, BLayout, CLayout,
F16, F16, F16, F32, F16,
PassThrough, PassThrough, PassThrough, GemmDefault,
2, 256,
256, 256,
32, 8, 4,
32, 32,
4, 4,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 8, 4, 0,
1, 1, S<1, 32, 1, 8>, 8,
ck::LoopScheduler::Default, ck::PipelineVersion::v1>;
// clang-format on
// Host reference GEMM instantiated with the same element types and element-wise ops.
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example.inc"
/// Entry point: forwards the command line to run_gemm_example and returns the
/// logical negation of its result as the process exit code.
int main(int argc, char* argv[])
{
    if(run_gemm_example(argc, argv))
    {
        return 0;
    }
    return 1;
}

View File

@@ -1,9 +1,7 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#include "common.hpp"
@@ -44,3 +42,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
#endif

View File

@@ -1,9 +1,7 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#include "common.hpp"
@@ -58,3 +56,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
#include "run_gemm_add_add_fastgelu_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_add_add_fastgelu_example(argc, argv); }
#endif

View File

@@ -1,9 +1,7 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#define BUILD_INT4_EXAMPLE
@@ -24,3 +22,4 @@ using RsDataType = ck::Tuple<R0DataType>;
#include "run_convnd_fwd_max_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); }
#endif

View File

@@ -272,15 +272,14 @@ int main(int argc, char* argv[])
{
for(int m = 0; m < M; ++m)
{
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
ReduceAccDataType d0_val = 0;
ReduceAccDataType d1_val = 0;
for(int n = 0; n < N; ++n)
{
auto c_val =
ck::type_convert<ReduceAccDataType>(c_g_m_n_host_result(batch, m, n));
ReduceAccDataType d0_val;
ReduceAccDataType d1_val;
UnaryIdenticElementOp{}(d0_val, c_val);
UnarySquareElementOp{}(d1_val, c_val);

View File

@@ -1,9 +1,7 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#include "common.hpp"
@@ -29,3 +27,4 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
#include "run_grouped_conv_fwd_bias_relu_add_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
#endif

View File

@@ -9,9 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
Gemm1
*/
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#include <iostream>
#include <numeric>
@@ -144,3 +142,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; }
#endif

View File

@@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 1});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =

View File

@@ -42,7 +42,7 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::KPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
// clang-format off

View File

@@ -1,9 +1,7 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#include <cstdlib>
#include <iostream>
@@ -120,3 +118,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; }
#endif

View File

@@ -32,6 +32,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
return {C_ * D * H * W, D * H * W, H * W, W, 1_uz};
else if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NDHWC>::value)
return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_};
throw std::runtime_error("Pool3d_fwd: problem with layout. ");
return {0, 0, 0, 0, 0};
};
template <typename TensorLayout>
@@ -53,6 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
return HostTensorDescriptor({N_, C_, D, H, W},
{D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
}
throw std::runtime_error("Pool3d_fwd: problem with layout. ");
return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0});
};
template <typename DevicePoolFwdInstance,

View File

@@ -26,6 +26,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
return {C_ * D * H * W, D * H * W, H * W, W, 1_uz};
else if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NDHWC>::value)
return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_};
throw std::runtime_error("Avgpool3d_bwd: problem with layout. ");
return {0, 0, 0, 0, 0};
};
template <typename TensorLayout>
@@ -47,6 +49,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
return HostTensorDescriptor({N_, C_, D, H, W},
{D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
}
throw std::runtime_error("Avgpool3d_bwd: problem with layout. ");
return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0});
};
template <typename DevicePoolBwdInstance,

View File

@@ -11,6 +11,6 @@ struct StreamConfig
hipStream_t stream_id_ = nullptr;
bool time_kernel_ = false;
int log_level_ = 0;
int cold_niters_ = 1;
int nrepeat_ = 10;
int cold_niters_ = 5;
int nrepeat_ = 50;
};

View File

@@ -0,0 +1,999 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
// Double LDS buffer
// Prefetch 2 stage
// Local prefetch 1 stage
namespace ck {
/// Compile-time instruction counts for a blockwise XDL GEMM tile.
///
/// Every count is the total amount of work for one MPerBlock x NPerBlock x KPerBlock
/// tile divided by BlockSize and by the per-instruction width, i.e. a per-thread
/// (per-wave for MFMA) number of instructions. The counts are consumed by instruction
/// schedulers that interleave the different instruction kinds.
///
/// @tparam BlockSize                 Number of threads in the block.
/// @tparam MPerBlock,NPerBlock,KPerBlock  Tile sizes handled by one block.
/// @tparam ABufferLoadWidth,BBufferLoadWidth  Elements moved by one buffer load of A / B.
/// @tparam ALDSWriteWidth,BLDSWriteWidth      Elements moved by one LDS write of A / B.
/// @tparam ALDSReadWidth,BLDSReadWidth        Elements moved by one LDS read of A / B.
/// @tparam MRepeat,NRepeat           Repeats of the XDL tile per wave along M / N.
/// @tparam MPerXDL,NPerXDL,KPerXDL   Sizes covered by a single XDL (MFMA) instruction.
template <index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t ABufferLoadWidth,
          index_t BBufferLoadWidth,
          index_t ALDSWriteWidth,
          index_t BLDSWriteWidth,
          index_t ALDSReadWidth,
          index_t BLDSReadWidth,
          index_t MRepeat,
          index_t NRepeat,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t KPerXDL>
struct BlockwiseGemmXdlops_pipeline_hotloop_inst
{
    // NOTE(review): hard-coded wave size; confirm it matches the target's warp size.
    static constexpr index_t WaveSize = 64;

    // Number of waves along M and N inside the block.
    static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
    static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);

    // Global (buffer) loads: the A tile is MPerBlock x KPerBlock, the B tile is
    // NPerBlock x KPerBlock, shared by all threads of the block.
    static constexpr index_t A_Buffer_Load_Inst_Num =
        MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
    static constexpr index_t B_Buffer_Load_Inst_Num =
        NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);

    // LDS writes of the same tiles.
    static constexpr index_t A_LDS_Write_Inst_Num =
        MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
    static constexpr index_t B_LDS_Write_Inst_Num =
        NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);

    // LDS reads: the A tile is read once per wave column (WaveNumN times in total) and
    // the B tile once per wave row (WaveNumM times in total). The B tile is
    // NPerBlock x KPerBlock, so its count must use NPerBlock (it previously used
    // MPerBlock, which is only correct when MPerBlock == NPerBlock).
    static constexpr index_t A_LDS_Read_Inst_Num =
        WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
    static constexpr index_t B_LDS_Read_Inst_Num =
        WaveNumM * NPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);

    // MFMA instructions per wave: total multiply-accumulate volume of the tile split
    // across the waves of the block, divided by the volume of one XDL instruction.
    static constexpr index_t C_MFMA_Inst_Num =
        MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);

    /// Prints the tile configuration and all instruction counts (debug helper).
    static constexpr auto Print()
    {
        printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
               BlockSize,
               WaveSize,
               MPerBlock,
               NPerBlock,
               KPerBlock,
               MPerXDL,
               NPerXDL,
               KPerXDL);
        printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
               "%d, %d\n C MFMA inst: %d\n",
               A_Buffer_Load_Inst_Num,
               B_Buffer_Load_Inst_Num,
               A_LDS_Write_Inst_Num,
               B_LDS_Write_Inst_Num,
               A_LDS_Read_Inst_Num,
               B_LDS_Read_Inst_Num,
               C_MFMA_Inst_Num);
    }
};
template <
index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false,
index_t AMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
index_t BMmaKStride =
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
struct BlockwiseGemmXdlops_pipeline_v4
{
// Compile-time index constants used to address tuple/descriptor dimensions.
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
// Lengths of dimensions 0 and 2 of the A / B tile descriptors.
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
// XDL GEMM primitive this pipeline is built on.
static constexpr auto xdlops_gemm =
XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{};
// K extent handled per thread, and how many KPack-sized chunks that is.
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
// Number of waves along M and N inside the block.
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
// Instruction counts used by the schedulers below. A_K1 / B_K1 are used as both the
// buffer-load and LDS-write widths, KPack as the LDS-read width.
using HotLoopInstList = BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
A_K1,
B_K1,
A_K1,
B_K1,
KPack,
KPack,
MRepeat,
NRepeat,
MPerXDL,
NPerXDL,
xdlops_gemm.KPerXdlops>;
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
// Per-thread accumulator storage: MRepeat * NRepeat vectors, each sized by
// xdlops_gemm.GetRegSizePerXdlops().
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_;
// Returns a mutable reference to this object's accumulator buffer.
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
// Splits the calling thread's id into a 3-component index over
// (MWaves, NWaves, WaveSize): wave position along M, wave position along N, and the
// position inside the wave.
__device__ static auto GetWaveIdx()
{
    constexpr auto tid_to_wave_lane = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
        make_tuple(Sequence<0, 1, 2>{}),
        make_tuple(Sequence<0>{}));

    const index_t tid = ThisThreadBlock::GetThreadId();

    return tid_to_wave_lane.CalculateBottomIndex(make_multi_index(tid));
}
// Origin of this thread's A data as a 4-tuple:
// (0, wave index along M, xdlops A index [1], KPack * xdlops A index [0]).
__device__ static auto CalculateAThreadOriginDataIndex()
{
    const auto m_wave       = GetWaveIdx()[I0];
    const auto a_xdl_origin = xdlops_gemm.CalculateAThreadOriginDataIndex();

    return make_tuple(0, m_wave, a_xdl_origin[I1], KPack * a_xdl_origin[I0]);
}
// Origin of this thread's B data as a 4-tuple:
// (0, wave index along N, xdlops B index [1], KPack * xdlops B index [0]).
__device__ static auto CalculateBThreadOriginDataIndex()
{
    const auto n_wave       = GetWaveIdx()[I1];
    const auto b_xdl_origin = xdlops_gemm.CalculateBThreadOriginDataIndex();

    return make_tuple(0, n_wave, b_xdl_origin[I1], KPack * b_xdl_origin[I0]);
}
// Computes the (m, n) origin of this thread's C data for repeat (m0, n0) and the
// given xdlops / block indices, by unmerging (repeat, wave, per-XDL offset) along
// each of M and N.
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
    // (MRepeat, MWaves, MPerXDL) -> M
    constexpr auto to_m = make_single_stage_tensor_adaptor(
        make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
        make_tuple(Sequence<0>{}),
        make_tuple(Sequence<0, 1, 2>{}));

    // (NRepeat, NWaves, NPerXDL) -> N
    constexpr auto to_n = make_single_stage_tensor_adaptor(
        make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
        make_tuple(Sequence<0>{}),
        make_tuple(Sequence<0, 1, 2>{}));

    const auto wave   = GetWaveIdx();
    const auto m_wave = wave[I0];
    const auto n_wave = wave[I1];

    const auto blk_origin = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);

    const index_t m = to_m.CalculateBottomIndex(make_tuple(m0, m_wave, blk_origin[I0]))[I0];
    const index_t n = to_n.CalculateBottomIndex(make_tuple(n0, n_wave, blk_origin[I1]))[I0];

    return make_tuple(m, n);
}
// 8-component variant of the C origin: (m0, n0, wave index along M, wave index along N)
// followed by the four components returned by xdlops_gemm.GetBeginOfThreadBlk4D.
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
    const auto wave   = GetWaveIdx();
    const auto m_wave = wave[I0];
    const auto n_wave = wave[I1];

    const auto blk_origin = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);

    return make_tuple(m0,
                      n0,
                      m_wave,
                      n_wave,
                      blk_origin[I0],
                      blk_origin[I1],
                      blk_origin[I2],
                      blk_origin[I3]);
}
// Type of the 4-component thread-origin index returned by the Calculate*ThreadOrigin helpers.
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
// Constructs the pipeline and initializes the A / B thread copies with their origins.
// By default the origins are computed from the calling thread's position; callers may
// pass explicit origins instead.
__host__ __device__
BlockwiseGemmXdlops_pipeline_v4(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
// Compile-time sanity checks on the configuration.
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// The block must contain exactly MWaves * NWaves waves.
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
// Block tile must be an exact multiple of the per-wave tile in both M and N.
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
// Debug aid: uncomment to print the instruction counts.
// HotLoopInstList::Print();
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
// Packed per-thread C descriptor with lengths
// (MRepeat, NRepeat, 1, 1, N, M0, M1, M2), where (M0, M1, M2, N) are the thread
// block lengths reported by xdlops_gemm.
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
    constexpr auto tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

    return make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                          Number<NRepeat>{},
                                                          I1,
                                                          I1,
                                                          tblk_lens[I3],
                                                          tblk_lens[I0],
                                                          tblk_lens[I1],
                                                          tblk_lens[I2]));
}
// XDL output supporting C_xdl = A_xdl * B_xdl
// Packed per-thread C descriptor with lengths
// (MRepeat, NRepeat, 1, 1, M0, M1, M2, N), where (M0, M1, M2, N) are the thread
// block lengths reported by xdlops_gemm.
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
    constexpr auto tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

    return make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                          Number<NRepeat>{},
                                                          I1,
                                                          I1,
                                                          tblk_lens[I0],
                                                          tblk_lens[I1],
                                                          tblk_lens[I2],
                                                          tblk_lens[I3]));
}
// Same as GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2 but with a leading dimension
// of length 1: lengths are (1, MRepeat, NRepeat, 1, 1, M0, M1, M2, N).
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
    constexpr auto tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

    return make_naive_tensor_descriptor_packed(make_tuple(I1,
                                                          Number<MRepeat>{},
                                                          Number<NRepeat>{},
                                                          I1,
                                                          I1,
                                                          tblk_lens[I0],
                                                          tblk_lens[I1],
                                                          tblk_lens[I2],
                                                          tblk_lens[I3]));
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
// Builds the packed block-level C descriptor over
// (MRepeat, NRepeat, MWaves, NWaves, MPerXDL, NPerXDL) and lets xdlops_gemm expand it.
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
    constexpr auto block_lengths = make_tuple(Number<MRepeat>{},
                                              Number<NRepeat>{},
                                              Number<MWaves>{},
                                              Number<NWaves>{},
                                              Number<MPerXDL>{},
                                              Number<NPerXDL>{});

    constexpr auto c_block_desc = make_naive_tensor_descriptor_packed(block_lengths);

    return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
// Builds the packed block-level C descriptor over
// (MRepeat, NRepeat, MWaves, NWaves, MPerXDL, NPerXDL) and lets xdlops_gemm expand it.
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
    constexpr auto block_lengths = make_tuple(Number<MRepeat>{},
                                              Number<NRepeat>{},
                                              Number<MWaves>{},
                                              Number<NWaves>{},
                                              Number<MPerXDL>{},
                                              Number<NPerXDL>{});

    constexpr auto c_block_desc = make_naive_tensor_descriptor_packed(block_lengths);

    return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc);
}
// Block-level C descriptor with a leading dimension of length 1: packed over
// (1, MRepeat, NRepeat, MWaves, NWaves, MPerXDL, NPerXDL), then expanded by xdlops_gemm.
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
    constexpr auto block_lengths = make_tuple(I1,
                                              Number<MRepeat>{},
                                              Number<NRepeat>{},
                                              Number<MWaves>{},
                                              Number<NWaves>{},
                                              Number<MPerXDL>{},
                                              Number<NPerXDL>{});

    constexpr auto c_block_desc = make_naive_tensor_descriptor_packed(block_lengths);

    return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc);
}
// Transforms an (M, N) grid descriptor for C: M is unmerged into
// (M / (MWaves * MPerXDL), MWaves, MPerXDL) on dimensions 0, 2, 4 and N into
// (N / (NWaves * NPerXDL), NWaves, NPerXDL) on dimensions 1, 3, 5; the result is then
// expanded by xdlops_gemm.
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
    const auto M = c_grid_desc_m_n.GetLength(I0);
    const auto N = c_grid_desc_m_n.GetLength(I1);

    const auto split_m =
        make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL));
    const auto split_n =
        make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL));

    const auto c_grid_desc_split =
        transform_tensor_descriptor(c_grid_desc_m_n,
                                    make_tuple(split_m, split_n),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));

    return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_split);
}
// Transforms a (G, M, N) grid descriptor for C: G is passed through on dimension 0,
// M is unmerged into (M / (MWaves * MPerXDL), MWaves, MPerXDL) on dimensions 1, 3, 5
// and N into (N / (NWaves * NPerXDL), NWaves, NPerXDL) on dimensions 2, 4, 6; the
// result is then expanded by xdlops_gemm.
template <typename CGridDesc_G_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
    const auto G = c_grid_desc_g_m_n.GetLength(I0);
    const auto M = c_grid_desc_g_m_n.GetLength(I1);
    const auto N = c_grid_desc_g_m_n.GetLength(I2);

    const auto keep_g = make_pass_through_transform(G);
    const auto split_m =
        make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL));
    const auto split_n =
        make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL));

    const auto c_grid_desc_split = transform_tensor_descriptor(
        c_grid_desc_g_m_n,
        make_tuple(keep_g, split_m, split_n),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));

    return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_split);
}
// Emits scheduling-group barriers that describe how the instructions of one hot-loop
// phase should be interleaved. One "issue" is emitted per buffer-load instruction;
// each issue interleaves MFMA instructions with a share of the DS reads, a share of
// the DS writes and one VMEM read, then the remaining MFMA share.
// The statement order below defines the requested instruction order; do not reorder.
__device__ static constexpr auto HotLoopScheduler()
{
// schedule
constexpr auto num_ds_read_inst =
HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
constexpr auto num_ds_write_inst =
HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
;
constexpr auto num_buffer_load_inst =
HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num;
;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto num_issue = num_buffer_load_inst;
static_for<0, num_issue, 1>{}([&](auto i) {
// The loop index is not needed; the assignment only silences unused warnings.
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(
0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(
0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// Three MFMAs were already placed above, hence the "- 3".
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA
});
}
// Primary template of the tail scheduler: emits no scheduling hints.
// Explicit specializations for stage 1 and stage 2 follow.
template <index_t stage>
__device__ static constexpr auto TailScheduler()
{
}
// Tail scheduler for stage 1: one issue per DS-write instruction. Each issue
// interleaves MFMA instructions with one DS write and a share of the DS reads, then
// the remaining MFMA share.
// The statement order below defines the requested instruction order; do not reorder.
template <>
__device__ static constexpr auto TailScheduler<1>()
{
// schedule
constexpr auto num_ds_read_inst =
HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
constexpr auto num_ds_write_inst =
HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto num_issue = num_ds_write_inst;
static_for<0, num_issue, 1>{}([&](auto i) {
// The loop index is not needed; the assignment only silences unused warnings.
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// One DS read was already placed above, hence the "- 1".
__builtin_amdgcn_sched_group_barrier(
0x100, num_ds_read_inst / num_ds_write_inst - 1, 0); // DS read
// Three MFMAs were already placed above, hence the "- 3".
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_inst / num_ds_write_inst - 3, 0); // MFMA
});
}
// Tail scheduler for stage 2: one issue per DS-read instruction, each pairing a
// single DS read with an equal share of the MFMA instructions.
template <>
__device__ static constexpr auto TailScheduler<2>()
{
// schedule
constexpr auto num_ds_read_inst =
HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto num_issue = num_ds_read_inst;
static_for<0, num_issue, 1>{}([&](auto i) {
// The loop index is not needed; the assignment only silences unused warnings.
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_inst / num_ds_read_inst, 0); // MFMA
});
}
// Default-constructed A / B tile descriptors used as the source descriptors when
// copying from the block buffers into the per-thread buffers.
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
template <bool HasMainLoop,
index_t TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
// Inst List:
// ds_read_b128: 16
// ds_write_b128: 8
// buffer_load_dwordx4: 16
// v_mfma: 0
// -------------------------------------------------------------------------------------------
// Global prefetch 1th, Fill Ping LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));
// Local prefetch 1th, Fill Ping Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(I0));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(I0),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(I0));
});
});
});
// Global prefetch 2th, Fill Pong LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));
// Global prefetch 3rd
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
// This hot loop has two legacy loopover, to implement the double local buffer strategy
do
{
// -------------------------------------------------------------------------------------------
using PingP1 = Number<0>;
using PongP1 = Number<1>;
// MFMA: Ping Reg
// DS_WRITE: To Ping LDS
// DS_READ: Pong LDS to Pong Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP1{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP1{}));
});
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
// -------------------------------------------------------------------------------------------
using PingP2 = Number<1>;
using PongP2 = Number<0>;
// MFMA: Pong Reg
// DS_WRITE: To Pong LDS
// DS_READ: Ping LDS to Ping Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP2{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP2{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP2{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP2{}));
});
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP2{}));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP2{}));
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 2;
} while(i < (num_loop - 3));
}
// tail
if constexpr(TailNum == 3)
{
using PingP1 = Number<0>;
using PongP1 = Number<1>;
// MFMA: Ping Reg
// DS_WRITE: To Ping LDS
// DS_READ: Pong LDS to Pong Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP1{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP1{}));
});
});
});
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
TailScheduler<1>();
__builtin_amdgcn_sched_barrier(0);
// -------------------------------------------------------------------------------------------
using PingP2 = Number<1>;
using PongP2 = Number<0>;
// MFMA: Pong Reg
// DS_WRITE: To Pong LDS
// DS_READ: Ping LDS to Ping Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP2{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP2{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP2{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP2{}));
});
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
TailScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PongP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PongP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// 64 v_mfma
__builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(TailNum == 2)
{
using PingP1 = Number<0>;
using PongP1 = Number<1>;
// MFMA: Ping Reg
// DS_WRITE: To Ping LDS
// DS_READ: Pong LDS to Pong Reg
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(PongP1{}),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_bufs(PongP1{}));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(PongP1{}),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_bufs(PongP1{}));
});
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
TailScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
// -------------------------------------------------------------------------------------------
using PingP2 = Number<1>;
// MFMA: Pong Reg
// DS_WRITE: To Pong LDS
// DS_READ: Ping LDS to Ping Reg
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
// 64 v_mfma
__builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
}
}
protected:
// M1, N1 as double buffer index
// Read buffer + Compute buffer
// A[M0, M1, M2, KPack]
// Per-thread register tile of A: M0 = MRepeat, M1 = 1, M2 = KRepeat, innermost KPack.
// Strides: consecutive KPack elements are contiguous, M0 advances by KPack and
// M2 (the K repeat) advances by one full row of MRepeat * KPack elements.
// The M1 stride is the size of one complete [M0, M2, KPack] slab. M1 has length 1,
// so this stride never contributes to an offset, but it is kept consistent with the
// shape (it used to be KPack * MRepeat * KPack, which only matches when
// KRepeat == KPack).
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
    make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
    make_tuple(
        Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
// B[N0, N1, N2, KPack]
// Per-thread register tile of B: N0 = NRepeat, N1 = 1, N2 = KRepeat, innermost KPack.
// Strides mirror a_thread_desc_ but must be expressed in NRepeat: the K-repeat
// dimension (N2) has to step over a full row of NRepeat * KPack elements.
// Using MRepeat here made (n0 = MRepeat, k) and (n0 = 0, k + 1) share the same
// register offset whenever NRepeat > MRepeat.
// The N1 stride is the size of one complete [N0, N2, KPack] slab; N1 has length 1,
// so it never contributes to an offset.
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
    make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
    make_tuple(
        Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));
// C[M, N, NumRegXdlops]
// Packed descriptor of the per-thread accumulator: MRepeat x NRepeat entries, each holding
// xdlops_gemm.GetRegSizePerXdlops() elements. The MFMA loops use
// CalculateOffset(make_tuple(m0, n0, 0)) on it to locate the accumulator of one
// (m0, n0) repeat inside c_thread_buf.
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
// Thread-level copy type behind a_thread_copy_. Its Run() is called with
// a_block_desc_m0_m1_m2_k / a_block_buf as the source and a_thread_desc_ / a_thread_bufs
// as the destination, one Sequence<1, 1, 1, KPack> slice per call.
// NOTE(review): the trailing arguments (Sequence<0, 1, 2, 3>, 3, A_K1, A_K1) follow the
// parameter list of ThreadwiseTensorSliceTransfer_v4 - presumably access order, vector
// dimension and per-access vector widths; confirm against its definition.
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
// Thread-level copy type behind b_thread_copy_. Its Run() is called with
// b_block_desc_n0_n1_n2_k / b_block_buf as the source and b_thread_desc_ / b_thread_bufs
// as the destination, one Sequence<1, 1, 1, KPack> slice per call.
// NOTE(review): the trailing arguments (Sequence<0, 1, 2, 3>, 3, B_K1, B_K1) follow the
// parameter list of ThreadwiseTensorSliceTransfer_v4 - presumably access order, vector
// dimension and per-access vector widths; confirm against its definition.
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
// Copy objects used by the pipeline body to fill a_thread_bufs / b_thread_bufs
// from a_block_buf / b_block_buf before each MFMA stage.
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck

View File

@@ -0,0 +1,306 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because
// the non-c-shuffle version currently has compiler issues with register spills, which in turn
// cause validation failures.
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
using DeviceOp = DeviceGemm_Xdl_CShuffleV2;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
// GridwiseGemm
// Grid-level GEMM implementation that this device op delegates to. Apart from the three
// literals marked inline, every argument is forwarded unchanged from one of
// DeviceGemm_Xdl_CShuffleV2's own template parameters.
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v2<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
InMemoryDataOperationEnum::Set, // fixed here; not exposed as a device-op template parameter
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false, // fixed literal; NOTE(review): presumably an A-side block-transfer flag - confirm in the gridwise header
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false, // fixed literal; NOTE(review): presumably the B-side counterpart of the flag above - confirm
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer,
ComputeTypeA,
ComputeTypeB>;
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
float ave_time = 0;
const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
if(GridwiseGemm::CalculateKBlockLoopTailNum(K) == 3)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true, 2>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
// Compile-time sanity hook for the template parameter combination.
// Currently a placeholder: it accepts every combination (see the TODO below).
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
// Reports whether `arg` can be run by this instance.
// Rejects when XDL is not supported (ck::is_xdl_supported()), when K is not a multiple
// of both AK1 and BK1 while GemmSpec is not one of the K-padding specializations, and
// otherwise defers to GridwiseGemm::CheckValidity.
static bool IsSupportedArgument(const Argument& arg)
{
    if(!ck::is_xdl_supported())
    {
        return false;
    }

    // Specializations that pad the K dimension.
    constexpr bool k_is_padded = GemmSpec == GemmSpecialization::MKPadding ||
                                 GemmSpec == GemmSpecialization::NKPadding ||
                                 GemmSpec == GemmSpecialization::MNKPadding ||
                                 GemmSpec == GemmSpecialization::KPadding;

    const bool k_is_aligned = (arg.K % AK1 == 0) && (arg.K % BK1 == 0);

    if(!k_is_aligned && !k_is_padded)
    {
        return false;
    }

    return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
// Builds an Argument by value from typed pointers and the problem description.
// p_a, p_b, p_c             - pointers to the A, B and C data.
// M, N, K                   - GEMM problem sizes.
// StrideA, StrideB, StrideC - strides passed through to Argument unchanged.
// The three trailing element-wise operation parameters are unnamed and not used here;
// they are accepted only to match the DeviceGemm-style signature.
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
}
// Returns a value-initialized Invoker by value.
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
// Type-erased factory: casts the untyped pointers to ADataType / BDataType / CDataType
// and returns a heap-allocated Argument owned by the caller. The three trailing
// element-wise operation parameters are unnamed and not used.
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                  const void* p_b,
                                                  void* p_c,
                                                  index_t M,
                                                  index_t N,
                                                  index_t K,
                                                  index_t StrideA,
                                                  index_t StrideB,
                                                  index_t StrideC,
                                                  AElementwiseOperation,
                                                  BElementwiseOperation,
                                                  CElementwiseOperation) override
{
    const auto* a_typed = static_cast<const ADataType*>(p_a);
    const auto* b_typed = static_cast<const BDataType*>(p_b);
    auto* c_typed       = static_cast<CDataType*>(p_c);

    return std::make_unique<Argument>(
        a_typed, b_typed, c_typed, M, N, K, StrideA, StrideB, StrideC);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceGemm_Xdl_CShuffleV2"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -134,6 +134,11 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
// Stores the problem sizes M and N and the M01 grouping factor (default 8).
// The constructor body is empty in normal builds.
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
: M_(M), N_(N), M01_(M01)
{
// Debug aid, compiled out: prints the stored values from global thread 0 only.
#if 0
if(get_thread_global_1d_id()==0){
printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
}
#endif
}
template <typename CGridDesc_M_N>
@@ -252,6 +257,302 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
BlockToCTileMap_M00_N0_M01Adapt;
};
// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt;
// Grouped variant of the M00_N0_M01 adaptive C-tile map.
// The 1D block id is first redistributed over GroupNum groups (blocks whose ids are
// congruent modulo GroupNum receive consecutive tile ids) and the resulting tile id is
// then mapped to an (M-tile, N-tile) pair with the adaptive-M01 scheme.
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
{
    static_assert(GroupNum > 0, "GroupNum must be positive");

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;
    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(
        const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
    __host__ __device__
    BlockToCTileMap_Grouped_M00_N0_M01Adapt(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;
    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
    operator=(const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
    operator=(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;

    // Stores the problem sizes M and N and the M01 grouping factor (default 8).
    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M,
                                                                index_t N,
                                                                index_t M01 = 8)
        : M_(M), N_(N), M01_(M01)
    {
        // Debug aid, compiled out: prints the stored values from global thread 0 only.
#if 0
        if(get_thread_global_1d_id()==0){
            printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
        }
#endif
    }

    // Convenience constructor taking M and N from dimensions 0 and 1 of a C grid descriptor.
    template <typename CGridDesc_M_N>
    __host__ __device__
    BlockToCTileMap_Grouped_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
        : BlockToCTileMap_Grouped_M00_N0_M01Adapt(
              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
    {
    }

    // Number of C tiles (and therefore blocks) needed for an M x N problem.
    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);

        return M0 * N0;
    }

    template <typename CGridDesc_M_N>
    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
    }

    // This map accepts every descriptor.
    template <typename CGridDesc_M_N>
    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    // Maps a 1D block id (idx_top[I0]) to the (M-tile, N-tile) index of the C tile it owns.
    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);

        const auto tile_count = M0 * N0;

        block_1d_id = block_1d_id % tile_count; // swallow batch index

        // Split the tile_count tile ids into GroupNum contiguous ranges. Group g receives
        // the blocks g, g + GroupNum, g + 2 * GroupNum, ...; when tile_count is not a
        // multiple of GroupNum the first `group_rem` groups own one block more than the
        // others, so their ranges must be one tile longer. Sizing every range as
        // ceil(tile_count / GroupNum) would push later groups past tile_count - 1.
        // For tile_count % GroupNum == 0 this is identical to
        // group_id * (tile_count / GroupNum) + block_1d_id / GroupNum.
        const auto group_len = tile_count / GroupNum;
        const auto group_rem = tile_count % GroupNum;

        const auto group_id    = block_1d_id % GroupNum;
        const auto group_extra = (group_id < group_rem) ? group_id : group_rem;
        const auto group_begin = group_id * group_len + group_extra;

        const auto remap_block_1d_id = group_begin + block_1d_id / GroupNum;

        index_t idx_N0 = remap_block_1d_id % N0;
        index_t idx_M0 = remap_block_1d_id / N0;

        // The last band of M tiles may hold fewer than M01_ rows.
        const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

        index_t idx_M00          = idx_M0 / M01_;
        index_t idx_M01          = idx_M0 % M01_;
        index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        /**
         * Adaptive-M01 traversal applied to the remapped id (shown for M01 = 2):
         *
         *                        idxN0
         *
         *           |<               mtx   N                 >|
         *
         *             NPerBlock   NPerBlock   NPerBlock   NPerBlock
         *                N_0         N_1        N_2         N_3
         *       -   |-----------|-----------|-----------|-----|-----|-
         *       ^   | -   -  0  |/---->  2  |           |     |     |
         *           | |   |     /     |     |           |     |     |  M_0  MPerBlock
         *           | M   |    /|     |     |           |     |     |
         *           |-0---|---/-|-----|-----|-----------|-----|-----|-
         *           | 1   |  /  |     |     |  blockid  |     |     |
         * idxM0     | |   | /   |     V     |     5     |     |     |  M_1  MPerBlock
         *           | -   V   1 |     -  3  |           |     |     |
         *           |-----------|-----------|-----------|-----|-----|-
         *    mtx M  |           |           |           |     |     |
         *           |           |           |           |     |     |  M_2  MPerBlock
         *           |           |           |           |     |     |
         *           |-----------|-----------|-----------|-----|-----|-
         *           |           |           |           |     |     |
         *           |           |           |           |     |     |  M_3  MPerBlock
         *           |           |           |           |     |     |
         *           |-----------|-----------|-----------|-----|-----|-
         *       V   |           |           |           |     |     |
         *       -   |-----------|-----------|-----------|-----|-----|-  M_4  MPerBlock
         *           |           |           |           |     |     |
         *           |-----------|-----------|-----------|-----|-----|-
         *  Example:
         *   assume:
         *      M0 = 5
         *      N0 = 4
         *      remap_block_1d_id = 5   (the id after the group remapping above)
         *      M01 = 2
         *
         *      idx_N0 = 1
         *      idx_M0 = 1
         *      M01_adapt = 2
         *      idx_M00 = 0
         *      idx_M01 = 1
         *      idx_N0_M01_local = 5
         *      output {1, 2}
         */

        return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                          idx_N0_M01_local / M01_adapt);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    private:
    index_t M_;
    index_t N_;
    index_t M01_;
};
// keep the redundant type argument for backward compatibility
// Primary template: for any CGridDesc_M_N it simply derives from the
// <GroupNum, MPerBlock, NPerBlock, void> specialization and inherits its constructors,
// so the descriptor type has no influence on the behaviour.
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
: BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
{
using BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>::
BlockToCTileMap_Grouped_M00_N0_M01Adapt;
};
// columns of row-vectors
// This C-tile map dynamically adjusts N01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_N00_M0_N01Adapt;
// C-tile map that walks the tile grid in vertical bands of N01 tile columns.
// Inside a band the tiles are visited row by row; the last band shrinks to the
// remaining N0 % N01 columns when N0 is not a multiple of N01.
template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt() = default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt&) =
        default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt&&) =
        default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
    operator=(const BlockToCTileMap_N00_M0_N01Adapt&) = default;
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
    operator=(BlockToCTileMap_N00_M0_N01Adapt&&) = default;

    // Stores the problem sizes M and N and the band width N01 (default 8).
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01 = 8)
        : M_(M), N_(N), N01_(N01)
    {
        // Debug aid, compiled out: prints the stored values from global thread 0 only.
#if 0
        if(get_thread_global_1d_id()==0){
            printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_);
        }
#endif
    }

    // Convenience constructor taking M and N from dimensions 0 and 1 of a C grid descriptor.
    template <typename CGridDesc_M_N>
    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
                                                        index_t N01 = 8)
        : BlockToCTileMap_N00_M0_N01Adapt(
              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), N01)
    {
    }

    // Number of C tiles (and therefore blocks) needed for an M x N problem.
    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const auto m_tiles = math::integer_divide_ceil(M, MPerBlock);
        const auto n_tiles = math::integer_divide_ceil(N, NPerBlock);

        return m_tiles * n_tiles;
    }

    template <typename CGridDesc_M_N>
    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
    }

    // This map accepts every descriptor.
    template <typename CGridDesc_M_N>
    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    // Maps a 1D block id (idx_top[I0]) to the (M-tile, N-tile) index of the C tile it owns.
    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        const auto m_tiles = math::integer_divide_ceil(M_, MPerBlock);
        const auto n_tiles = math::integer_divide_ceil(N_, NPerBlock);

        auto tile_id = idx_top[I0];
        tile_id      = tile_id % (m_tiles * n_tiles); // swallow batch index

        // Column-major position of the tile before banding.
        const index_t m_tile = tile_id % m_tiles;
        const index_t n_tile = tile_id / m_tiles;

        // Width of the band this tile falls into: N01_ everywhere except in the
        // trailing partial band.
        const auto band_width = (n_tile < n_tiles - n_tiles % N01_) ? N01_ : n_tiles % N01_;

        const index_t band_id     = n_tile / N01_;
        const index_t col_in_band = n_tile % N01_;
        const index_t local_id    = m_tile + col_in_band * m_tiles;

        /**
         *                        idxN0
         *
         *           |<               mtx   N                 >|
         *
         *           |<---N01--->|
         *       -   |-----------|-----------|-----------|-----|-----|-
         *       ^   |  0 ----------> 1      |           |     |     |
         *           |           /           |           |     |     |  M_0  MPerBlock
         *           |          /            |           |     |     |
         *           |------/----------------|-----------|-----|-----|-
         *           |     |     |           |           |     |     |
         * idxM0     |     V     |           |           |     |     |  M_1  MPerBlock
         *           |  2 ----------> 3      |           |     |     |
         *           |-----------|-----------|-----------|-----|-----|-
         *    mtx M  |           |  blockid  |           |     |     |
         *           |           |     5     |           |     |     |  M_2  MPerBlock
         *           |           |           |           |     |     |
         *           |-----------|-----------|-----------|-----|-----|-
         *           |           |           |           |     |     |
         *           |           |           |           |     |     |  M_3  MPerBlock
         *           |           |           |           |     |     |
         *           |-----------|-----------|-----------|-----|-----|-
         *       V   |           |           |           |     |     |
         *       -   |-----------|-----------|-----------|-----|-----|-  M_4  MPerBlock
         *           |           |           |           |     |     |
         *           |-----------|-----------|-----------|-----|-----|-
         *             NPerBlock   NPerBlock   NPerBlock   NPerBlock
         *                N_0         N_1        N_2         N_3
         *  Example:
         *   assume:
         *      N0 = 5
         *      M0 = 4
         *      block_1d_id = 5
         *      N01 = 2
         *
         *      idx_M0 = 1
         *      idx_N0 = 1
         *      N01_adapt = 2
         *      idx_N00 = 0
         *      idx_N01 = 1
         *      idx_M0_N01_local = 5
         *      output {2, 1}
         */

        return make_tuple(local_id / band_width, local_id % band_width + band_id * N01_);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    private:
    index_t M_;
    index_t N_;
    index_t N01_;
};
// 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>

File diff suppressed because it is too large Load Diff

View File

@@ -268,6 +268,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
{
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
@@ -329,6 +344,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
{
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(

View File

@@ -189,6 +189,7 @@ struct vector_type<T, 1>
}
};
// Fallback object returned by the `else { return err; }` branches of the
// vector_type<T, N> accessors in this file, presumably for element types that match
// none of the handled cases.
// NOTE(review): as a non-const namespace-scope `static`, every translation unit that
// includes this header gets its own mutable copy - confirm this is intended.
int static err = 0;
template <typename T>
struct vector_type<T, 2>
{
@@ -221,6 +222,10 @@ struct vector_type<T, 2>
{
return data_.d2x1_;
}
else
{
return err;
}
}
template <typename X>
@@ -236,6 +241,10 @@ struct vector_type<T, 2>
{
return data_.d2x1_;
}
else
{
return err;
}
}
};
@@ -278,6 +287,10 @@ struct vector_type<T, 4>
{
return data_.d4x1_;
}
else
{
return err;
}
}
template <typename X>
@@ -298,6 +311,10 @@ struct vector_type<T, 4>
{
return data_.d4x1_;
}
else
{
return err;
}
}
};
@@ -347,6 +364,10 @@ struct vector_type<T, 8>
{
return data_.d8x1_;
}
else
{
return err;
}
}
template <typename X>
@@ -372,6 +393,10 @@ struct vector_type<T, 8>
{
return data_.d8x1_;
}
else
{
return err;
}
}
};
@@ -428,6 +453,10 @@ struct vector_type<T, 16>
{
return data_.d16x1_;
}
else
{
return err;
}
}
template <typename X>
@@ -458,6 +487,10 @@ struct vector_type<T, 16>
{
return data_.d16x1_;
}
else
{
return err;
}
}
};
@@ -520,6 +553,10 @@ struct vector_type<T, 32>
{
return data_.d32x1_;
}
else
{
return err;
}
}
template <typename X>
@@ -554,6 +591,10 @@ struct vector_type<T, 32>
{
return data_.d32x1_;
}
else
{
return err;
}
}
};
@@ -623,6 +664,10 @@ struct vector_type<T, 64>
{
return data_.d64x1_;
}
else
{
return err;
}
}
template <typename X>
@@ -662,6 +707,10 @@ struct vector_type<T, 64>
{
return data_.d64x1_;
}
else
{
return err;
}
}
};
@@ -737,6 +786,10 @@ struct vector_type<T, 128>
{
return data_.d128x1_;
}
else
{
return err;
}
}
template <typename X>
@@ -780,6 +833,10 @@ struct vector_type<T, 128>
{
return data_.d128x1_;
}
else
{
return err;
}
}
};
@@ -861,6 +918,10 @@ struct vector_type<T, 256>
{
return data_.d256x1_;
}
else
{
return err;
}
}
template <typename X>
@@ -908,6 +969,10 @@ struct vector_type<T, 256>
{
return data_.d256x1_;
}
else
{
return err;
}
}
};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -19,6 +19,12 @@ struct is_known_at_compile_time<index_t>
static constexpr bool value = false;
};
// Specialization for `unsigned int`: a plain unsigned integer carries a runtime
// value, so it is reported as NOT known at compile time (value == false).
template <>
struct is_known_at_compile_time<unsigned int>
{
static constexpr bool value = false;
};
template <>
struct is_known_at_compile_time<long_index_t>
{

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -14,22 +14,28 @@ namespace wrapper {
* \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
* (dynamic layout). It is possible to pass nested shapes
* (e.g. ((4, 2), 2)), nested dimensions are merged.
* \tparam UnnestedDescriptorType Tensor descriptor for unnested shape dims.
* \tparam UnrolledDescriptorType Tensor descriptor for unnested shape dims.
*/
template <typename Shape, typename UnnestedDescriptorType>
template <typename Shape, typename UnrolledDescriptorType>
struct Layout
{
private:
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
// Generate default idxs tuple (idx with all merged nested shapes)
/**
* \brief Generate default indices tuple (idx with all merged nested shapes)
*
* \param shape Shape to align.
* \return Multi idx tuple with zeros.
*/
template <typename... Ts>
__host__ __device__ constexpr static auto GenerateDefaultIdxsTuple(const Tuple<Ts...>&)
__host__ __device__ constexpr static auto
GenerateDefaultIdxsTuple([[maybe_unused]] const Tuple<Ts...>& shape)
{
return generate_tuple(
[&](auto) {
if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
// runtime layout
return index_t(0);
@@ -43,11 +49,18 @@ struct Layout
Number<Tuple<Ts...>::Size()>{});
}
// Generate LowerDims in Compile-time for MergeTrasform using passed Type
// If element of Tuple<Ts...> is also tuple, then merge (generate sequence for merge)
// If tuple is element, then pass through (sequence with one element)
/**
* \brief Generate lower dims in compile-time for the Merge transform using
* provided type. If element of nested Tuple<Ts...> is also a tuple, then
* merge (generate sequence for merge). If tuple is element, then pass
* through (sequence with one element).
*
* \param shape Shape to align.
* \return LowerDims for MergeTransform.
*/
template <typename Idx, typename... Ts>
__host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>&)
__host__ __device__ constexpr static auto
GenerateLowerDim([[maybe_unused]] const Tuple<Ts...>& shape)
{
if constexpr(Idx::value == 0)
{
@@ -87,11 +100,17 @@ struct Layout
}
}
// Iterate over nested tuples in shape
// Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
// Example idx: (1, 1), 1, 1
// Example shape: (2, (2, 2)), 2, (2, 2)
// Unrolled shape: 2, (2, 2), 2, (2, 2)
/**
* \brief Iterate over the nested tuples in the shape.
* Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
* Example idx: (1, 1), 1, 1
* Example shape: (2, (2, 2)), 2, (2, 2)
* Unrolled shape: 2, (2, 2), 2, (2, 2)
*
* \param shape Layout shape.
* \param idx Idx to align.
* \return Aligned shape.
*/
template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx)
@@ -126,6 +145,13 @@ struct Layout
}
}
/**
* \brief Merge descriptor to 1D.
*
* \param shape Layout shape.
* \param desc Descriptor to merge.
* \return 1D descriptor.
*/
template <typename... ShapeDims, typename DescriptorToMerge>
__host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
const DescriptorToMerge& desc)
@@ -137,18 +163,41 @@ struct Layout
const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
const auto upper_dims = make_tuple(Sequence<0>{});
// Merge to 1d
return transform_tensor_descriptor(
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
return transform_tensor_descriptor(
desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
}
else
{
// If the descriptor is known at the compilation time,
// use `make_merge_transform_v1_carry_check` because it doesn't use
// memcpy.
return transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform_v1_carry_check(merge_elems)),
lower_dims,
upper_dims);
}
}
// Merge nested shape dims when corresponding index is also nested.
// Input desc shape: 2, 2, 2, 2, 2, 2
// Example idx: 1, 1, 1, 1
// Example shape: 2, (2, 2), 2, (2, 2)
// Merged shape: 2, 4, 2, 4
/**
* \brief Merge nested shape dims when corresponding index is also merged.
* Input desc shape: 2, 2, 2, 2, 2, 2
* Example idx: 1, 1, 1, (1, 1)
* Example shape: 2, (2, 2), 2, (2, 2)
* Merged shape: 2, 4, 2, 2, 2
*
* \param shape Layout shape.
* \param idxs Indexes to align descriptor.
* \param desc Descriptor to merge.
* \return Aligned descriptor to idx.
*/
template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
__host__ __device__ constexpr static auto CreateMergedDescriptor(
const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
__host__ __device__ constexpr static auto
CreateMergedDescriptor(const Tuple<ShapeDims...>& shape,
[[maybe_unused]] const Tuple<IdxDims...>& idxs,
DescriptorToMerge& desc)
{
const auto transforms = generate_tuple(
[&](auto i) {
@@ -160,7 +209,17 @@ struct Layout
// If shape element is tuple and idx element is Number, then merge
// Unroll and reverse tuple to traverse column-major
const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
return make_merge_transform(merge_elems);
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
return make_merge_transform(merge_elems);
}
else
{
// If the descriptor is known at the compilation time,
// use `make_merge_transform_v1_carry_check` because
// it doesn't use memcpy.
return make_merge_transform_v1_carry_check(merge_elems);
}
}
else
{
@@ -185,14 +244,23 @@ struct Layout
}
using Descriptor1dType =
remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnnestedDescriptorType{}))>;
remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
public:
/**
* \brief Transform descriptor to align to passed indexes.
*
* \param shape Layout shape.
* \param idxs Indexes to align descriptor.
* \param naive_descriptor Descriptor to merge.
* \return Aligned descriptor to idx.
*/
template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto
TransformDesc(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx,
const UnnestedDescriptorType& naive_descriptor)
const Tuple<IdxDims...>& idxs,
const UnrolledDescriptorType& naive_descriptor)
{
if constexpr(Tuple<IdxDims...>::Size() == I1)
{
@@ -208,19 +276,18 @@ struct Layout
static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
"Idx rank and Shape rank must be the same (except 1d).");
// Unroll while IdxDims is nested
const auto aligned_shape = AlignShapeToIdx(shape, idx);
const auto aligned_shape = AlignShapeToIdx(shape, idxs);
// Transform correct form of shape
return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), naive_descriptor);
return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idxs), naive_descriptor);
}
}
using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
Shape{}, DefaultIdxsTupleType{}, UnnestedDescriptorType{}))>;
Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))>;
public:
__host__ __device__ constexpr auto GetElementSpaceSize() const
{
return unnested_descriptor_.GetElementSpaceSize();
return unrolled_descriptor_.GetElementSpaceSize();
}
__host__ __device__ Layout() = delete;
@@ -232,16 +299,15 @@ struct Layout
* \param unnested_descriptor Descriptor
*/
__host__ __device__ constexpr Layout(const Shape& shape,
const UnnestedDescriptorType& unnested_descriptor)
: shape_(shape)
const UnrolledDescriptorType& unnested_descriptor)
: unrolled_descriptor_(unnested_descriptor), shape_(shape)
{
// Construct if runtime mode
if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
{
unnested_descriptor_ = unnested_descriptor;
descriptor_1d_ = MakeMerge1d(shape_, unnested_descriptor_);
descriptor_1d_ = MakeMerge1d(shape_, unrolled_descriptor_);
merged_nests_descriptor_ =
TransformDesc(shape_, DefaultIdxsTupleType{}, unnested_descriptor_);
TransformDesc(shape_, DefaultIdxsTupleType{}, unrolled_descriptor_);
}
}
@@ -254,9 +320,9 @@ struct Layout
template <typename Idxs>
__host__ __device__ constexpr index_t operator()() const
{
static_assert(UnnestedDescriptorType::IsKnownAtCompileTime(),
static_assert(remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime(),
"Compiletime operator used on runtime layout.");
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnnestedDescriptorType{}));
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnrolledDescriptorType{}));
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
}
@@ -283,7 +349,7 @@ struct Layout
else
{
// Custom index, need to transform descriptor
const auto transformed_desc = TransformDesc(shape_, Idx, unnested_descriptor_);
const auto transformed_desc = TransformDesc(shape_, Idx, unrolled_descriptor_);
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
}
}
@@ -350,29 +416,55 @@ struct Layout
}
/**
* \brief Get default descriptor (with the same size as Shape)
* \brief Get descriptor with all nested dimensions merged.
* Example, shape: ((2, 2), 2)
* Descriptor lengths: (4, 2)
*
* \return Default descriptor.
* \note The size of merged descriptor is the same as Layout's shape.
*
* \return Merged nests descriptor.
*/
__host__ __device__ constexpr const MergedNestsDescriptorType& GetDefaultDescriptor() const
__host__ __device__ constexpr const MergedNestsDescriptorType&
GetMergedNestingDescriptor() const
{
return merged_nests_descriptor_;
}
/**
* \brief Get unnested descriptor (with unrolled dims)
* \brief Get descriptor with all dimensions merged (1D).
* Example, shape: ((2, 2), 2)
* Descriptor lengths: (8)
*
* \return Flatten descriptor.
* \return 1D descriptor.
*/
__host__ __device__ constexpr const UnnestedDescriptorType& GetUnnestedDescriptor() const
__host__ __device__ constexpr const Descriptor1dType& Get1DDescriptor() const
{
return unnested_descriptor_;
return descriptor_1d_;
}
/**
* \brief Get unrolled descriptor (every nested dimension flattened).
* Example, shape: ((2, 2), 2)
* Descriptor lengths: (2, 2, 2)
*
* \return Flattened descriptor.
*/
__host__ __device__ constexpr const UnrolledDescriptorType& GetUnrolledDescriptor() const
{
return unrolled_descriptor_;
}
private:
UnnestedDescriptorType unnested_descriptor_;
// All dimensions are unrolled
UnrolledDescriptorType unrolled_descriptor_;
// 1D descriptor
Descriptor1dType descriptor_1d_;
// All nested dimensions are merged
MergedNestsDescriptorType merged_nests_descriptor_;
// Example, shape: ((2, 2), 2)
// UnrolledDescriptorType lengths: (2, 2, 2)
// Descriptor1dType lengths: (8)
// MergedNestsDescriptorType lengths: (4, 2)
const Shape shape_;
};

View File

@@ -1,16 +1,21 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "../utils/tensor_utils.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace wrapper {
/**
* \brief Perform generic copy between two tensors. Tensors must have the
* same size.
* \brief Perform generic copy between two tensor partitions (threadwise copy).
* Tensors must have the same size.
*
* \param src_tensor Source tensor.
* \param dst_tensor Destination tensor.
@@ -37,5 +42,134 @@ __host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& ds
}
}
/**
* \brief Perform optimized copy between two tensor partitions (threadwise copy).
* Tensors must have the same size.
*
* The transfer implementation is selected at compile time from the
* `IsDynamicBuffer` flag of the source and destination tensor types:
* - dynamic -> dynamic: ThreadwiseTensorSliceTransfer_v7
* - static  -> dynamic: ThreadwiseTensorSliceTransfer_v1r3
* - dynamic -> static:  ThreadwiseTensorSliceTransfer_v4r1
* - static  -> static:  falls back to the generic `copy(src_tensor, dst_tensor)`
*
* \note The slice lengths and the number of dimensions are derived from the
* source tensor shape only; the destination shape is not consulted here.
* \note This overload is `__device__` only.
*
* \tparam DimAccessOrderTuple Tuple with dimension access order.
* \tparam VectorDim Dimension for vectorized read and write.
* \tparam ScalarPerVector Number of scalar per vectorized read and write.
* \param src_tensor Source tensor.
* \param dst_tensor Destination tensor.
*/
template <typename DimAccessOrderTuple,
index_t VectorDim,
index_t ScalarPerVector,
typename SrcTensorType,
typename DstTensorType>
__device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
// The access order must be passed as a Tuple type.
static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// Both transfers operate on the unrolled descriptors of the layouts.
const auto& in_grid_desc = layout(src_tensor).GetUnrolledDescriptor();
const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();
using SrcShapeType = remove_cvref_t<decltype(shape(src_tensor))>;
constexpr index_t num_dims = SrcShapeType::Size();
// Per-dimension slice lengths, taken from the (default-constructed) source shape type.
constexpr auto thread_slice_lengths =
generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
// Convert the access-order Tuple into a Sequence with one entry per dimension.
constexpr auto dim_access_order = generate_sequence_v2(
[](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});
if constexpr(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
{
// Perform a copy between DynamicBuffers
// Source and destination are each passed as one-element tuples/ties, the
// element-wise operation is PassThrough and the memory operation is Set.
auto transfer = ThreadwiseTensorSliceTransfer_v7<
Tuple<typename SrcTensorType::TensorElementType>,
Tuple<typename DstTensorType::TensorElementType>,
decltype(tie(in_grid_desc)),
decltype(tie(out_grid_desc)),
tensor_operation::element_wise::PassThrough,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
decltype(thread_slice_lengths),
decltype(dim_access_order),
VectorDim,
ScalarPerVector,
Sequence<false>,
Sequence<false>>{in_grid_desc,
make_tuple(src_tensor.GetMultiIdxOffsets()),
out_grid_desc,
make_tuple(dst_tensor.GetMultiIdxOffsets()),
tensor_operation::element_wise::PassThrough{}};
transfer.Run(tie(in_grid_desc),
tie(src_tensor.GetBuffer()),
tie(out_grid_desc),
tie(dst_tensor.GetBuffer()));
}
else if constexpr(!SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
{
// Perform copy from StaticBuffer to DynamicBuffer
// The source slice origin is all zeros (I0 for every dimension); the
// destination origin comes from the destination tensor's multi-index offsets.
const auto src_slice_origin_idxs =
generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
auto transfer =
ThreadwiseTensorSliceTransfer_v1r3<typename SrcTensorType::TensorElementType,
typename DstTensorType::TensorElementType,
remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>,
tensor_operation::element_wise::PassThrough,
decltype(thread_slice_lengths),
decltype(dim_access_order),
VectorDim,
ScalarPerVector,
InMemoryDataOperationEnum::Set,
I1,
true>{out_grid_desc,
dst_tensor.GetMultiIdxOffsets(),
tensor_operation::element_wise::PassThrough{}};
transfer.Run(in_grid_desc,
src_slice_origin_idxs,
src_tensor.GetBuffer(),
out_grid_desc,
dst_tensor.GetBuffer());
}
else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
{
// Perform copy from DynamicBuffer to StaticBuffer
// The same all-zero origin is passed to Run() for both source and destination.
const auto src_dst_slice_origin =
generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
// Vector tensor lengths: ScalarPerVector along VectorDim, 1 along every other dimension.
constexpr auto src_vector_tensor_lengths = generate_sequence_v2(
[&](auto I) {
if constexpr(I == VectorDim)
{
return Number<ScalarPerVector>{};
}
else
{
return I1;
}
},
Number<num_dims>{});
// dim_access_order is used for both order parameters of the transfer.
auto transfer =
ThreadwiseTensorSliceTransfer_v4r1<typename SrcTensorType::TensorElementType,
typename DstTensorType::TensorElementType,
remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>,
decltype(thread_slice_lengths),
decltype(dim_access_order),
decltype(src_vector_tensor_lengths),
decltype(dim_access_order)>{
src_tensor.GetMultiIdxOffsets()};
transfer.Run(in_grid_desc,
src_dst_slice_origin,
src_tensor.GetBuffer(),
out_grid_desc,
src_dst_slice_origin,
dst_tensor.GetBuffer());
}
else
{
// Perform copy between StaticBuffers
// No explicit template arguments are given, so this resolves to the generic
// copy overload (this overload's first three parameters cannot be deduced).
copy(src_tensor, dst_tensor);
}
}
} // namespace wrapper
} // namespace ck

View File

@@ -10,189 +10,205 @@
namespace ck {
namespace wrapper {
namespace detail {
// NOTE: these helpers live directly in `detail` and are intentionally NOT wrapped
// in an unnamed namespace. This is a header: an unnamed namespace would give every
// translation unit its own internal-linkage copy of each helper, while the Tensor
// member templates that call them have external linkage (an ODR hazard). All helpers
// are templates, so they may safely be defined in a header without it.

/**
 * \brief Check if an index is a Slice object.
 *
 * \return True if the argument type is detected as a Slice.
 */
template <typename T>
__host__ __device__ constexpr bool HasSlice(T&&)
{
    return is_detected<is_slice, T>::value;
}

/**
 * \brief Check if Tuple contains Slice object (searched recursively through
 *        nested tuples).
 *
 * \note This overload only binds to rvalue tuples; it is meant to be called
 *       with a default-constructed temporary of the tuple type.
 *
 * \return True if tuple contains Slice object.
 */
template <typename... Ts>
__host__ __device__ constexpr bool HasSlice(Tuple<Ts...>&&)
{
    return (HasSlice(Ts{}) || ...);
}

/**
 * \brief Calculate new shape after slice from parent shape.
 *
 * For every element of idxs:
 *  - nested tuple without any slice: the dimension is dropped,
 *  - nested tuple with a slice: processed recursively,
 *  - slice: the new length is the slice range over the dimension size,
 *  - scalar index: the dimension is dropped.
 *
 * \param idxs Tuple of indexes defining slice ranges.
 * \param shape Shape which will be sliced.
 * \return New tensor shape.
 */
template <typename... Ts, typename SlicedShape>
__host__ __device__ constexpr auto GetSlicedShape(const Tuple<Ts...>& idxs,
                                                  const SlicedShape& shape)
{
    // Pack each value in tuple to remove empty tuples after generation
    auto new_shape = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                if constexpr(!detail::HasSlice(tuple_element_t<i.value, Tuple<Ts...>>{}))
                {
                    // if tuple does not have any slice then we can remove dimension
                    return Tuple<>{};
                }
                else
                {
                    // nested tuple containing a slice: recurse into it
                    return make_tuple(GetSlicedShape(idxs.At(num_i), shape.At(num_i)));
                }
            }
            else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                // calculate new dimension
                const auto& dim = size(shape.At(num_i));
                const auto val  = idxs.At(num_i).range(dim);
                return make_tuple(val);
            }
            else
            {
                // remove dimension for just value
                return Tuple<>{};
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Remove empty tuples (deleted elements) and return
    return UnrollNestedTuple<0, 1>(new_shape);
}

/**
 * \brief Generate a Freeze transform for each dimension of a (possibly nested) shape.
 *
 * The scalar index is decomposed into one index per unrolled dimension by
 * repeatedly taking `idx % dim` and then dividing `idx` by `dim`.
 *
 * \note NOTE(review): the decomposition mutates `idx` inside the functor, so the
 *       result depends on generate_tuple invoking the functor for i = 0, 1, ...
 *       in ascending order - confirm against generate_tuple's implementation.
 *
 * \param idx Scalar index to freeze on (taken by value, consumed during decomposition).
 * \param shape Shape which will be frozen.
 * \return Generated freeze transforms.
 */
template <typename T, typename Shape>
__host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, const Shape& shape)
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return generate_tuple(
        [&](auto i) {
            // dimension offset from idx
            const auto dim     = unrolled_shape.At(Number<i>{});
            const auto dim_idx = idx % dim;
            idx /= dim;
            return make_freeze_transform(dim_idx);
        },
        Number<decltype(unrolled_shape)::Size()>{});
}

/**
 * \brief Generate transforms for slice tensor.
 *
 * For every element of idx:
 *  - nested tuple: processed recursively,
 *  - slice: a slice transform over [from, from + range),
 *  - scalar index: freeze transforms for every unrolled dimension of that
 *    shape element.
 *
 * \param idx Tuple of start indices for slice.
 * \param shape Shape which will be sliced.
 * \return Generated transforms.
 */
template <typename... Ts, typename Shape>
__host__ __device__ constexpr auto GenerateSliceTransforms(const Tuple<Ts...>& idx,
                                                           const Shape& shape)
{
    // Each element yields either a single transform or a tuple of transforms
    auto transforms = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                return GenerateSliceTransforms(idx.At(num_i), shape.At(num_i));
            }
            else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                const auto from  = idx.At(num_i).from_;
                const auto dim   = size<num_i>(shape);
                const auto range = idx.At(num_i).range(dim);
                return make_slice_transform(range, from, from + range);
            }
            else
            {
                // scalar index: freeze the dimension(s) at the requested position
                return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Unroll the nested per-element results into one tuple of transforms
    return UnrollNestedTuple(transforms);
}

/**
 * \brief Upper dimension ids produced by a Freeze transform.
 *
 * \return Empty Sequence - there is no output for Freeze transform.
 */
template <index_t i, typename LowerIndex>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<LowerIndex>&)
{
    // There is no output for Freeze transform
    return Sequence<>{};
}

/**
 * \brief Upper dimension ids produced by a Slice transform.
 *
 * \tparam i Upper dimension id assigned to this transform.
 * \return Sequence<i>.
 */
template <index_t i, typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Slice<LowLength, SliceBegin, SliceEnd>&)
{
    return Sequence<i>{};
}

/**
 * \brief Recursion end for GenerateUpperDims (no transforms left).
 *
 * \return Empty tuple.
 */
template <index_t i>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&)
{
    return Tuple<>{};
}

/**
 * \brief Generate the upper dimension Sequence for every transform.
 *
 * Transforms are processed front to back. A Freeze transform contributes an
 * empty Sequence and leaves the running id unchanged; any other (Slice)
 * transform contributes Sequence<i> and increments the id for the rest.
 *
 * \tparam i Upper dimension id to assign to the next Slice transform.
 * \param transforms Tuple of Freeze / Slice transforms.
 * \return Tuple of Sequences, one per transform.
 */
template <index_t i, typename... Transforms>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>& transforms)
{
    constexpr auto num_transforms = Tuple<Transforms...>::Size();
    // Deduce Sequence element for specific transform
    const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
    if constexpr(is_same_v<decltype(current_elem), const Sequence<>>)
    {
        // Freeze transform: keep i for the remaining transforms
        const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
    else
    {
        // Increase i if current_elem is Slice transform
        const auto next_tuple = GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
}

/**
 * \brief Build the descriptor of a sliced tensor from the flattened descriptor.
 *
 * Lower dimensions are Sequence<0>, Sequence<1>, ... - one per unrolled
 * dimension of the parent shape. Upper dimensions are derived from the types
 * of the generated transforms via GenerateUpperDims.
 *
 * \param idx Tuple of indices: slices or scalars.
 * \param shape Shape of the tensor being sliced.
 * \param flatten_desc Flattened descriptor of the tensor being sliced.
 * \return Transformed (sliced) descriptor.
 */
template <typename... Ts, typename Shape, typename FlattenDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
                                                            const Shape& shape,
                                                            const FlattenDescriptor& flatten_desc)
{
    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();

    const auto transforms     = GenerateSliceTransforms(idx, shape);
    using TransformsTupleType = decltype(transforms);

    const auto lower_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
    // Upper dims depend only on the transform types, so compute them from the type
    const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
    return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
} // namespace detail
/**
* \brief Tensor wrapper that performs static and dynamic buffer logic.
* The tensor is based on a descriptor stored in the Layout. Additionally,
* tensor can be sliced or shifted using multi-index offset.
*
* \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
* \tparam ElementType Element data type.
* \tparam Shape Tensor shape (layout component).
* \tparam UnnestedDescriptorType Unnested descriptor (layout component).
* \tparam NumVectors Number of vectors (only for VGPR, SGPR).
* \tparam ScalarPerVector Scalars per vector (only for VGPR, SGPR).
* \tparam UnrolledDescriptorType Flatten descriptor (layout component).
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType,
index_t NumVectors, // param for Register memory
index_t ScalarPerVector // param for Register memory
>
typename UnrolledDescriptorType>
struct Tensor
{
private:
// Check if Tuple contains Slice object
template <typename T>
__host__ __device__ constexpr static bool IsSlicing(T&&)
{
return is_detected<is_slice, T>::value;
}
template <typename... Ts>
__host__ __device__ constexpr static bool IsSlicing(Tuple<Ts...>&&)
{
return (IsSlicing(Ts{}) || ...);
}
// Calculate new tensor shape after slice
template <typename... Ts, typename ShapeTmpType>
__host__ __device__ constexpr auto GetShapeFromSlicedTensor(const Tuple<Ts...>& idx,
const ShapeTmpType& shape) const
{
// Pack each value in tuple to remove empty tuples after generation
auto new_shape = generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
if constexpr(!IsSlicing(tuple_element_t<i.value, Tuple<Ts...>>{}))
{
// if tuple does not have any slice then we can remove dimension
return Tuple<>{};
}
else
{
// if tuple then recurrence
return make_tuple(GetShapeFromSlicedTensor(idx.At(num_i), shape.At(num_i)));
}
}
else if constexpr(is_detected<is_slice,
tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// calculate new dimension
const auto& dim = size(shape.At(num_i));
const auto val = idx.At(num_i).range(dim);
return make_tuple(val);
}
else
{
// remove dimension for just value
return Tuple<>{};
}
},
Number<Tuple<Ts...>::Size()>{});
// Remove empty tuples (deleted elements) and return
return UnrollNestedTuple<0, 1>(new_shape);
}
// Generate Freeze for each of nested shape
template <typename T, typename ShapeTmpType>
__host__ __device__ constexpr auto GenerateMultipleFreeze(T idx,
const ShapeTmpType& shape) const
{
const auto unrolled_shape = UnrollNestedTuple(shape);
return generate_tuple(
[&](auto i) {
// dimension offset from idx
const auto dim = unrolled_shape.At(Number<i>{});
const auto dim_idx = idx % dim;
idx /= dim;
return make_freeze_transform(dim_idx);
},
Number<decltype(unrolled_shape)::Size()>{});
}
template <typename... Ts, typename ShapeTmpType>
__host__ __device__ constexpr auto
GetTransformsFromSlicedTensor(const Tuple<Ts...>& idx, const ShapeTmpType& shape) const
{
// Pack each value in tuple to remove empty tuples after generation
auto transforms = generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
return GetTransformsFromSlicedTensor(idx.At(num_i), shape.At(num_i));
}
else if constexpr(is_detected<is_slice,
tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
const auto from = idx.At(num_i).from_;
const auto dim = shape.At(num_i);
const auto range = idx.At(num_i).range(dim);
return make_slice_transform(range, from, from + range);
}
else
{
// remove dimension for just value
return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
}
},
Number<Tuple<Ts...>::Size()>{});
// Remove empty tuples (deleted elements) and return
return UnrollNestedTuple(transforms);
}
// There is no output for Freeze transform
template <index_t i, typename LowerIndex>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<LowerIndex>&) const
{
return Sequence<>{};
}
template <index_t i, typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto
GetSequenceVal(const ck::Slice<LowLength, SliceBegin, SliceEnd>&) const
{
return Sequence<i>{};
}
template <index_t i>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&) const
{
return Tuple<>{};
}
template <index_t i, typename... Transforms>
__host__ __device__ constexpr auto
GenerateUpperDims(const Tuple<Transforms...>& transforms) const
{
constexpr auto num_transforms = Tuple<Transforms...>::Size();
// Deduce Sequence element for specific transform
const auto currect_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
if constexpr(is_same_v<decltype(currect_elem), const Sequence<>>)
{
const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
return concat_tuple(make_tuple(currect_elem), next_tuple);
}
else
{
// Increase i if current_elem is Slice transform
const auto next_tuple =
GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
return concat_tuple(make_tuple(currect_elem), next_tuple);
}
}
template <typename... Ts, typename ShapeTmpType, typename FlattenDescriptor>
__host__ __device__ constexpr auto
GetDescriptorFromSlicedTensor(const Tuple<Ts...>& idx,
const ShapeTmpType& shape,
const FlattenDescriptor& flatten_desc) const
{
constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
const auto transforms = GetTransformsFromSlicedTensor(idx, shape);
using TransformsTupleType = decltype(transforms);
const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
public:
using ElementSpaceSize = decltype(Layout<Shape, UnnestedDescriptorType>{
Shape{}, UnnestedDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
using ElementSpaceSize = decltype(Layout<Shape, UnrolledDescriptorType>{
Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
using TensorElementType = ElementType; // DataType
static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
@@ -200,134 +216,207 @@ struct Tensor
BufferAddressSpace == MemoryTypeEnum ::Vgpr);
__host__ __device__ Tensor() = delete;
__host__ __device__ Tensor(ElementType* pointer,
const Layout<Shape, UnnestedDescriptorType>& layout)
__host__ __device__ constexpr Tensor(ElementType* pointer,
const Layout<Shape, UnrolledDescriptorType>& layout)
: layout_(layout),
buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize()))
buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize())),
multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
base_offset_(0)
{
static_assert(IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
}
__host__ __device__ Tensor(const Layout<Shape, UnnestedDescriptorType>& layout)
: layout_(layout)
__host__ __device__ constexpr Tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
: layout_(layout),
multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
base_offset_(0)
{
static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
}
__host__ __device__ constexpr const Layout<Shape, UnnestedDescriptorType>& GetLayout() const
__host__ __device__ constexpr const Layout<Shape, UnrolledDescriptorType>& GetLayout() const
{
return layout_;
}
// Getter for new sliced tensor
template <typename... Ts, enable_if_t<IsSlicing(Tuple<Ts...>{}), bool> = false>
__host__ __device__ auto operator[](const Tuple<Ts...>& idx) const
/**
* \brief Get the new sliced tensor.
*
* \param idx Tuple of indices: slice(from,to) or scalar.
* \return Sliced tensor.
*/
template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ auto operator[](const Tuple<Ts...>& idx)
{
static_assert(IsDynamicBuffer, "Register slice is not supported");
const auto& shape = layout_.GetShape();
auto new_shape = GetShapeFromSlicedTensor(idx, shape);
auto new_shape = detail::GetSlicedShape(idx, shape);
const auto& flatten_desc = layout_.GetUnnestedDescriptor();
auto new_desc = GetDescriptorFromSlicedTensor(idx, shape, flatten_desc);
const auto& flatten_desc = layout_.GetUnrolledDescriptor();
auto new_desc = detail::GenerateSlicedDescriptor(idx, shape, flatten_desc);
const auto new_layout =
Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
// Update embed offset
base_offset_ -= new_layout(make_tuple(Number<0>{}));
return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
}
template <typename... Ts, enable_if_t<IsSlicing(Tuple<Ts...>{}), bool> = false>
__host__ __device__ auto operator()(const Tuple<Ts...>& idx) const
template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ auto operator()(const Tuple<Ts...>& idx)
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<IsSlicing(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ auto operator()(Idxs... idxs) const
template <typename... Idxs, enable_if_t<detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ auto operator()(Idxs... idxs)
{
return this->operator[](make_tuple(idxs...));
}
// Getter for the const value
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
/**
* \brief Getter of the tensor's const value reference.
*
* \param idx Tuple of indices.
* \return Requested value.
*/
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ const ElementType& operator[](const Tuple<Ts...>& idx) const
{
if constexpr(IsDynamicBuffer)
{
const index_t offset = layout_(idx);
const index_t offset = layout_(idx) + base_offset_;
return buffer_[offset];
}
else
{
constexpr index_t offset = Layout<Shape, UnnestedDescriptorType>{
constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
Shape{},
UnnestedDescriptorType{}}.template operator()<Tuple<Ts...>>();
return buffer_[Number<offset>{}];
UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
// Calculate and apply base offset in compile-time
constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
Shape{},
UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
return buffer_[Number<index_offset + base_offset>{}];
}
}
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ const ElementType& operator()(const Tuple<Ts...>& idx) const
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<!IsSlicing(Tuple<Idxs...>{}), bool> = false>
template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ const ElementType& operator()(Idxs... idxs) const
{
return this->operator[](make_tuple(idxs...));
}
// Getter for the value reference
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
/**
* \brief Getter of tensor value reference.
*
* \param idx Tuple of indices.
* \return Requested value.
*/
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ ElementType& operator[](const Tuple<Ts...>& idx)
{
if constexpr(IsDynamicBuffer)
{
const index_t offset = layout_(idx);
const index_t offset = layout_(idx) + base_offset_;
return buffer_(offset);
}
else
{
constexpr index_t offset = Layout<Shape, UnnestedDescriptorType>{
constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
Shape{},
UnnestedDescriptorType{}}.template operator()<Tuple<Ts...>>();
return buffer_(Number<offset>{});
UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
// Apply embed offset (calculate in compiletime)
constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
Shape{},
UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
return buffer_(Number<index_offset + base_offset>{});
}
}
template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
__host__ __device__ ElementType& operator()(const Tuple<Ts...>& idx)
{
return this->operator[](idx);
}
template <typename... Idxs, enable_if_t<!IsSlicing(Tuple<Idxs...>{}), bool> = false>
template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
__host__ __device__ ElementType& operator()(Idxs... idxs)
{
return this->operator[](make_tuple(idxs...));
}
__host__ __device__ constexpr auto GetDefaultDescriptor()
/**
* \brief Get descriptor with all nested dimensions merged.
*
* \return Merged nests descriptor.
*/
__host__ __device__ constexpr auto GetMergedNestingDescriptor()
{
return layout_.GetDefaultDescriptor();
return layout_.GetMergedNestingDescriptor();
}
/**
* \brief Get pointer to the data.
*
* \note Returns the pointer stored in the buffer (buffer_.p_data_) as is;
* base_offset_ is not added here.
*
* \return Pointer.
*/
__host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; }
/**
* \brief Get the buffer held by the tensor.
*
* \return Reference to the buffer.
*/
__host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
/**
* \brief Get the buffer held by the tensor (overload for const tensors).
*
* \return Reference to the buffer.
*/
__host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }
/**
* \brief Get multi index offset to the data.
*
* \return Reference to the stored multi index offset (multi_idx_offset_).
*/
__host__ __device__ constexpr auto& GetMultiIdxOffsets() const { return multi_idx_offset_; }
/**
* \brief Apply multi index offset on the tensor.
*
* Stores the given multi index offset and adds the matching linear offset,
* computed through the tensor layout, to the base offset.
*
* \param multi_idx_offset Multi index offset.
*/
template <typename MultiIdxOffsets>
__host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
{
// The stored multi index offset is replaced by the new one.
multi_idx_offset_ = multi_idx_offset;
// NOTE(review): base_offset_ is accumulated (+=) while multi_idx_offset_
// is overwritten, so calling this method more than once would leave the
// two members describing different elements. Confirm it is only called
// once on a freshly created tensor, or whether plain assignment is meant.
base_offset_ += layout_(multi_idx_offset);
}
private:
using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
ElementType,
ElementSpaceSize,
true /*InvalidElementUseNumericalZeroValue*/>;
using StaticBufferType =
StaticBufferTupleOfVector<BufferAddressSpace,
ElementType,
NumVectors,
ScalarPerVector,
true /*InvalidElementUseNumericalZeroValue*/>;
using StaticBufferType = StaticBuffer<BufferAddressSpace,
ElementType,
size(Shape{}),
true /*InvalidElementUseNumericalZeroValue*/>;
// If register use static buffer, else use dynamic buffer
using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;
const Layout<Shape, UnnestedDescriptorType> layout_;
const Layout<Shape, UnrolledDescriptorType> layout_;
Buffer buffer_;
// We use multi_idx_offset_ to enable the creation of a descriptor in
// compile time for partitions or tiles if tile shape and thread layout
// is known at compile time (We can use the same descriptor for each
// thread). Additionally, the copy between the static and dynamic buffer
// requires a descriptor known at compile time, so we can shift data using
// such multi_idx_offset_.
MultiIndex<Shape::Size()> multi_idx_offset_;
// Base offset and multi index offset are corresponding to exactly the
// same element in tensor ( and in physical memory ). Multi index offset
// is multi dimensional index. However base offset is calculated using
// tensor descriptor (thus all its transforms) and is linear (1D).
// We store base_offset_ to avoid multiple recalculations.
index_t base_offset_;
};
} // namespace wrapper

View File

@@ -22,14 +22,19 @@ namespace wrapper {
// Disable from doxygen docs generation
/// @cond
// forward declaration
template <typename Shape, typename UnnestedDescriptorType>
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
namespace {
// Generate packed (column-major) strides if not passed
/**
* \brief Generate packed (column-major) strides if not passed
*
* \param shape Tensor shape.
* \return Generated column-major strides.
*/
template <typename... Ts>
__host__ __device__ constexpr static auto
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
@@ -50,9 +55,16 @@ GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
Number<decltype(unrolled_shape)::Size()>{});
}
/**
* \brief Create naive tensor descriptor from nested shape.
*
* \param shape Tensor shape.
* \param strides Tensor strides.
* \return Unrolled descriptor
*/
template <typename LayoutShape, typename LayoutStrides>
__host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shape,
const LayoutStrides& strides)
__host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& shape,
const LayoutStrides& strides)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
if constexpr(is_same_v<LayoutStrides, Tuple<>>)
@@ -86,8 +98,8 @@ __host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shap
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
{
using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Strides{}));
return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, strides));
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
}
/**
@@ -100,15 +112,19 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides
template <typename Shape>
__host__ __device__ constexpr auto make_layout(const Shape& shape)
{
using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Tuple<>{}));
return Layout<Shape, UnnestedDescriptorType>(shape, MakeFlattenDescriptor(shape, Tuple<>{}));
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
}
// Layout helpers
// get
// Get dim (could be returned from get with empty Idxs)
/**
* \private
* \brief Get dim.
*
* \param dim Dimension.
* \return Returned the same dimension.
*/
template <typename T>
__host__ __device__ T constexpr get(const T& dim)
@@ -178,7 +194,7 @@ __host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
},
Number<old_shape_dims>{});
const auto& flatten_desc = layout.GetUnnestedDescriptor();
const auto& flatten_desc = layout.GetUnrolledDescriptor();
auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
}
@@ -197,9 +213,12 @@ __host__ __device__ constexpr auto get(const T& elem)
}
// size
// Get dim size (could be returned from get function)
/**
* \private
* \brief Get size.
*
* \param dim Size.
* \return Returned the same size.
*/
template <typename T>
__host__ __device__ T constexpr size(const T& dim)
@@ -214,8 +233,8 @@ __host__ __device__ T constexpr size(const T& dim)
* \param layout Layout to get Shape of.
* \return Requested length.
*/
template <index_t idx, typename Shape, typename UnnestedDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
template <index_t idx, typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
{
return layout.template GetLength<idx>();
}
@@ -240,8 +259,8 @@ __host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
* \param layout Layout to calculate shape size.
* \return Requested size.
*/
template <typename Shape, typename UnnestedDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnnestedDescriptorType>& layout)
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
{
return layout.GetLengths();
}
@@ -280,9 +299,9 @@ __host__ __device__ constexpr auto size(const T& elem)
* \param layout Layout to calculate rank.
* \return Requested rank.
*/
template <typename Shape, typename UnnestedDescriptorType>
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
rank([[maybe_unused]] const Layout<Shape, UnnestedDescriptorType>& layout)
rank([[maybe_unused]] const Layout<Shape, UnrolledDescriptorType>& layout)
{
return Shape::Size();
}
@@ -302,17 +321,25 @@ __host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& t
/**
* \private
* \brief Rank for scalar
*
* \param dim Dimension scalar.
* \return Returned 1.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t rank(const Number<IDim>&)
__host__ __device__ constexpr index_t rank([[maybe_unused]] const Number<IDim>& dim)
{
return 1;
}
/**
* \private
* \brief Rank for scalar
*
* \param dim Dimension scalar.
* \return Returned 1.
*/
__host__ __device__ constexpr index_t rank(const index_t&) { return 1; }
__host__ __device__ constexpr index_t rank([[maybe_unused]] const index_t& dim) { return 1; }
/**
* \brief Hierarchical rank.
@@ -334,8 +361,8 @@ __host__ __device__ constexpr auto rank(const T& elem)
* \param layout Layout to calculate depth.
* \return Requested depth.
*/
template <typename Shape, typename UnnestedDescriptorType>
__host__ __device__ constexpr auto depth(const Layout<Shape, UnnestedDescriptorType>& layout)
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto depth(const Layout<Shape, UnrolledDescriptorType>& layout)
{
const auto& shape = layout.GetShape();
return TupleDepth(shape);
@@ -355,17 +382,25 @@ __host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
/**
* \private
* \brief Depth for scalar
*
* \param dim Scalar.
* \return Returned 0.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t depth(const Number<IDim>&)
__host__ __device__ constexpr index_t depth([[maybe_unused]] const Number<IDim>& dim)
{
return 0;
}
/**
* \private
* \brief Depth for scalar
*
* \param dim Scalar.
* \return Returned 0.
*/
__host__ __device__ constexpr index_t depth(const index_t&) { return 0; }
__host__ __device__ constexpr index_t depth([[maybe_unused]] const index_t& dim) { return 0; }
/**
* \brief Hierarchical depth.

View File

@@ -6,12 +6,22 @@
#include "tensor_utils.hpp"
#include "layout_utils.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
namespace ck {
namespace wrapper {
namespace {
// Calculate shape for partition based on number of threads per each dim and
// previous shape
/**
* \brief Calculate shape for partition based on number of threads per each dim and
* previous shape
*
* \param shape Base tensor shape.
* \param thread_lengths Tuple of thread lengths.
* \return Partition shape.
*/
template <typename... Ts, typename... Ls>
__host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts...>& shape,
const Tuple<Ls...>& thread_lengths)
@@ -20,265 +30,165 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// if tuple then recurrence
return CalculateLocalPartitionShape(shape.At(num_i), thread_lengths.At(num_i));
}
else
{
const auto slice_len = shape.At(num_i) / thread_lengths.At(num_i);
return slice_len;
}
},
Number<Tuple<Ts...>::Size()>{});
}
// Calculate shape for partition based on number of threads per each dim,
// previous strides and steps
template <typename... Ts, typename... Ls, typename... Steps, typename FlattenDescType>
__host__ __device__ constexpr auto
CalculateLocalPartitionDescriptor(const Tuple<Ts...>& shape,
const Tuple<Ls...>& thread_lengths,
const Tuple<Steps...>& steps,
const FlattenDescType& flatten_desc)
{
static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
const auto unrolled_thread_lengths = UnrollNestedTuple(thread_lengths);
const auto unrolled_shape = UnrollNestedTuple(shape);
constexpr auto dims = decltype(unrolled_thread_lengths)::Size();
using UnrolledStepsType = decltype(UnrollNestedTuple(steps));
using I1 = Number<1>;
const auto transforms = generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
{
// By default raked partition
const auto partition_stride = unrolled_thread_lengths.At(num_i);
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
make_tuple(partition_stride));
}
else if constexpr(!is_same_v<tuple_element_t<i.value, UnrolledStepsType>, index_t>)
{
// Compiletime partition
if constexpr(is_same_v<tuple_element_t<i.value, UnrolledStepsType>, I1>)
{
// raked
const auto partition_stride = unrolled_thread_lengths.At(num_i);
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
make_tuple(partition_stride));
}
else
{
// packed
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
make_tuple(I1{}));
}
}
else
{
// Runtime partition
if(steps.At(num_i) == 1)
{
// raked
const auto partition_stride = unrolled_thread_lengths.At(num_i);
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
make_tuple(partition_stride));
}
else
{
// packed
return make_embed_transform(make_tuple(unrolled_shape.At(num_i)),
make_tuple(I1{}));
}
}
},
Number<dims>{});
const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
const auto upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
template <typename... Ls, typename... Steps>
__host__ __device__ constexpr auto CalculateLayoutOffsetIdxImpl(const Tuple<Ls...>& thread_lengths,
const Tuple<Steps...>& steps,
index_t& thread_id)
{
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ls...>>>::value)
{
// if tuple then recurrence
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
{
return CalculateLayoutOffsetIdxImpl(
thread_lengths.At(num_i), Tuple<>{}, thread_id);
}
else
{
return CalculateLayoutOffsetIdxImpl(
thread_lengths.At(num_i), steps.At(num_i), thread_id);
}
}
else
{
// Update thread_id after each dim
const auto dim_thread_id = thread_id % thread_lengths.At(num_i);
thread_id /= thread_lengths.At(num_i);
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
{
return dim_thread_id;
}
else
{
// Apply step
return steps.At(num_i) * dim_thread_id;
}
}
const auto slice_len = size<num_i>(shape) / thread_lengths.At(num_i);
return slice_len;
},
Number<Tuple<Ls...>::Size()>{});
}
// Convert integer thread_idx to tuple index with steps applied
template <typename... Ls, typename... Steps>
__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple<Ls...>& thread_lengths,
const Tuple<Steps...>& steps,
const index_t thread_id)
/**
* \brief Calculate total number of blocks.
*
* \param shape Base tensor shape.
* \param tile_shape Tile shape.
* \return Tuple with blocks number.
*/
template <typename... Ts, typename... Ls>
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
const Tuple<Ls...>& tile_shape)
{
// Create tmp thread_id copy for CalculateLayoutOffsetIdxImpl updates
index_t thread_id_copy = thread_id;
return CalculateLayoutOffsetIdxImpl(thread_lengths, steps, thread_id_copy);
static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
return generate_tuple([&](auto i) { return size<i>(shape) / size<i>(tile_shape); },
Number<Tuple<Ls...>::Size()>{});
}
// Apply steps to index represented as tuple
template <typename... Steps, typename... Idxs>
__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple<Steps...>& steps,
const Tuple<Idxs...>& block_idxs)
/**
* \brief Calculate scaled offset for new partition/tile.
*
* \param thread_idxs Multi index of the thread (or block).
* \param partition_lengths_seq Sequence of partition shape.
* \param old_offset_idxs Multi index offset from base tensor to shift values.
* \return Multi index offset of the new partition/tile.
*/
template <typename ThreadIdxs, typename PartitionLengthsSeq, typename OldOffsetIdxs>
__host__ __device__ constexpr auto
CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
const PartitionLengthsSeq& partition_lengths_seq,
const OldOffsetIdxs& old_offset_idxs)
{
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Idxs...>>>::value)
{
// if tuple then recurrence
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
{
return CalculateLayoutOffsetIdx(Tuple<>{}, block_idxs.At(num_i));
}
else
{
return CalculateLayoutOffsetIdx(steps.At(num_i), block_idxs.At(num_i));
}
}
else
{
if constexpr(is_same_v<Tuple<Steps...>, Tuple<>>)
{
return block_idxs.At(num_i);
}
else
{
// apply step
return steps.At(num_i) * block_idxs.At(num_i);
}
}
},
Number<Tuple<Idxs...>::Size()>{});
return thread_idxs * partition_lengths_seq + old_offset_idxs;
}
// User passes only shape per block to the make_local_tile function. This function calculates
// block layout based on the shape.
template <typename... Ts, typename... BlockDims>
__host__ __device__ constexpr auto CalculateBlockLengths(const Tuple<Ts...>& shape,
const Tuple<BlockDims...>& tile_shape)
{
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
{
// if tuple then recurrence
return CalculateBlockLengths(shape.At(num_i), tile_shape.At(num_i));
}
else
{
return shape.At(num_i) / tile_shape.At(num_i);
}
},
Number<Tuple<Ts...>::Size()>{});
}
} // namespace
/**
* \brief Create local partition for thread.
* \brief Create local partition for thread (currently only packed
* partitions are supported).
*
* \param tensor Tensor for partition.
* \param thread_lengths Layout of threads.
* \param thread_lengths Layout of threads (must not be nested).
* \param thread_id Thread index represented as integer.
* \param steps Thread step (default=1, raked partition)
* \return Partition tensor.
*/
template <typename TensorType, typename ThreadLengthsTuple, typename StepsTuple = Tuple<>>
__host__ __device__ constexpr auto make_local_partition(const TensorType& tensor,
const ThreadLengthsTuple& thread_lengths,
const index_t thread_id,
const StepsTuple steps = StepsTuple{})
template <typename TensorType, typename ThreadLengthsTuple>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
[[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
const index_t thread_id)
{
// Create shape, strides and layout for new partition tensor
const auto partition_shape = CalculateLocalPartitionShape(shape(tensor), thread_lengths);
// Create new descriptor and layout
const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor();
auto partition_desc =
CalculateLocalPartitionDescriptor(shape(tensor), thread_lengths, steps, flatten_desc);
const auto partition_layout = Layout<decltype(partition_shape), decltype(partition_desc)>(
partition_shape, partition_desc);
// Calculate offset for new partition tensor
const auto offset_idx = CalculateLayoutOffsetIdx(thread_lengths, steps, thread_id);
const auto partition_offset = layout(tensor)(offset_idx);
return make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer() + partition_offset,
partition_layout);
static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
// Calculate new partition shape
const auto& tensor_shape = shape(tensor);
constexpr auto partition_shape =
CalculateLocalPartitionShape(decltype(tensor_shape){}, ThreadLengthsTuple{});
// Create Thread Cluster Descriptor
constexpr auto partition_lengths_seq = generate_sequence_v2(
[&](auto I) { return size<I>(partition_shape); }, Number<ThreadLengthsTuple::Size()>{});
constexpr auto thread_lengths_seq =
generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
Number<ThreadLengthsTuple::Size()>{});
constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
// Calculate thread idxs and offsets
const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
const auto offset_multi_idxs =
CalculateOffsetMultiIdxs(thread_idxs, partition_lengths_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
auto& flatten_desc = layout(tensor).GetUnrolledDescriptor();
const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(flatten_desc)>(
partition_shape, flatten_desc);
auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
// Apply offsets
partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
return partition_tensor;
}
/**
* \brief Create local tile for thread block.
* \brief Create local tile for thread block (currently only packed
* tiles are supported).
*
* \note For now, use a 2D tile_shape to get the best
* performance.
*
*
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_idx Block index represented as tuple.
* \param steps Block step (default=1, raked partition)
* \param block_id Block index represented as integer.
* \return Tile tensor.
*/
template <typename TensorType,
typename BlockShapeTuple,
typename BlockIdxTuple,
typename StepsTuple = Tuple<>>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
const BlockShapeTuple& tile_shape,
const BlockIdxTuple& block_idx,
const StepsTuple steps = StepsTuple{})
template <typename TensorType, typename BlockShapeTuple>
__host__ __device__ constexpr auto
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
{
// Create block lengths, strides and layout for new tile tensor
const auto block_lengths = CalculateBlockLengths(shape(tensor), tile_shape);
// Create new descriptor and layout
const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor();
auto tile_desc =
CalculateLocalPartitionDescriptor(tile_shape, block_lengths, steps, flatten_desc);
const auto tile_layout = Layout<remove_reference_t<decltype(tile_shape)>, decltype(tile_desc)>(
tile_shape, tile_desc);
// Calculate offset for new partition tensor
const auto offset_idx = CalculateLayoutOffsetIdx(steps, block_idx);
const auto tile_offset = layout(tensor)(offset_idx);
return make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer() + tile_offset,
tile_layout);
static_assert(!IsNestedTuple(BlockShapeTuple{}));
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();
if constexpr(BlockShapeTuple::Size() == I2)
{
// Optimized version for 2d tile shape [MxK]
const auto block_2_tile_map =
BlockToCTileMap_M00_N0_M01Adapt<BlockShapeTuple{}.At(I0),
BlockShapeTuple{}.At(I1),
remove_cvref_t<decltype(aligned_desc)>>(aligned_desc);
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape));
const index_t k_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape));
const auto offset_multi_idxs =
make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid);
// Create new layout and tensor
const auto tile_layout =
Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
aligned_desc);
auto tile_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
// Apply offsets
tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
return tile_tensor;
}
else
{
// Calculate offsets
// Sequence with data to process per block
constexpr auto tile_shape_seq =
generate_sequence_v2([](auto I) { return size(BlockShapeTuple{}.At(I)); },
Number<BlockShapeTuple::Size()>{});
// Tuple with number of blocks
const auto block_lengths = CalculateGridSize(shape(tensor), tile_shape);
constexpr auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
const auto block_idxs =
block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
const auto offset_multi_idxs =
CalculateOffsetMultiIdxs(block_idxs, tile_shape_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
const auto tile_layout =
Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
aligned_desc);
auto tile_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
// Apply offsets
tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
return tile_tensor;
}
}
} // namespace wrapper

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -10,6 +10,7 @@
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/dynamic_buffer.hpp"
#include "ck/utility/amd_address_space.hpp"
#include "ck/utility/multi_index.hpp"
namespace ck {
namespace wrapper {
@@ -27,16 +28,12 @@ using MemoryTypeEnum = AddressSpaceEnum;
// Disable from doxygen docs generation
/// @cond
// forward declarations
template <typename Shape, typename UnnestedDescriptorType>
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType,
index_t NumVectors, // params for Register memory
index_t ScalarPerVector // param for Register memory
>
typename UnrolledDescriptorType>
struct Tensor;
template <typename FromType, typename ToType>
@@ -45,13 +42,22 @@ struct Slice
__host__ __device__ constexpr Slice() : from_(), to_() {}
__host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {}
/**
* \brief Calculate slice range.
*
* \param dim Dimension size.
* \return Slice range.
*/
template <typename T>
__host__ __device__ constexpr auto range(const T& dim) const
{
if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
is_same_v<T, index_t>)
{
assert(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_) && "Invalid range");
if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_)))
{
throw std::runtime_error("Invalid range");
}
if(to_ < 0)
{
return dim - from_ + to_ + 1;
@@ -101,40 +107,27 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
template <MemoryTypeEnum MemoryType,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType>
typename UnrolledDescriptorType>
constexpr auto make_tensor(ElementType* pointer,
const Layout<Shape, UnnestedDescriptorType>& layout)
const Layout<Shape, UnrolledDescriptorType>& layout)
{
return Tensor<MemoryType,
ElementType,
Shape,
UnnestedDescriptorType,
0 /*NumVectors*/,
0 /*ScalarPerVector*/>(pointer, layout);
return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(pointer, layout);
}
/**
* \brief Make SGPR or VGPR tensor function.
*
* \tparam MemoryType Type of memory.
* \tparam NumVectors Number of vectors.
* \tparam ScalarPerVector Scalars per vector.
* \tparam ElementType Memory data type.
* \return Constructed tensor.
*/
template <MemoryTypeEnum MemoryType,
index_t NumVectors,
index_t ScalarPerVector,
typename ElementType>
constexpr auto make_register_tensor()
typename ElementType,
typename Shape,
typename UnrolledDescriptorType>
constexpr auto make_register_tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
{
const auto layout = make_layout(make_tuple(Number<NumVectors>{}), make_tuple(Number<1>{}));
return Tensor<MemoryType,
ElementType,
Tuple<Number<NumVectors>>,
std::remove_const_t<remove_reference_t<decltype(layout.GetUnnestedDescriptor())>>,
NumVectors,
ScalarPerVector>(layout);
return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(layout);
}
/**
@@ -146,15 +139,9 @@ constexpr auto make_register_tensor()
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr const auto& layout(const Tensor<BufferAddressSpace,
ElementType,
Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
typename UnrolledDescriptorType>
__host__ __device__ constexpr const auto&
layout(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
return tensor.GetLayout();
}
@@ -170,15 +157,9 @@ template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr auto size(const Tensor<BufferAddressSpace,
ElementType,
Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
size(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
return size<Idxs...>(tensor.GetLayout());
}
@@ -194,15 +175,9 @@ template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr auto rank(const Tensor<BufferAddressSpace,
ElementType,
Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
rank(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
return rank<Idxs...>(tensor.GetLayout());
}
@@ -218,15 +193,9 @@ template <index_t... Idxs,
MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr auto depth(const Tensor<BufferAddressSpace,
ElementType,
Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
typename UnrolledDescriptorType>
__host__ __device__ constexpr auto
depth(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
return depth<Idxs...>(tensor.GetLayout());
}
@@ -240,15 +209,9 @@ __host__ __device__ constexpr auto depth(const Tensor<BufferAddressSpace,
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnnestedDescriptorType,
index_t NumVectors,
index_t ScalarPerVector>
__host__ __device__ constexpr const auto& shape(const Tensor<BufferAddressSpace,
ElementType,
Shape,
UnnestedDescriptorType,
NumVectors,
ScalarPerVector>& tensor)
typename UnrolledDescriptorType>
__host__ __device__ constexpr const auto&
shape(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
return shape(tensor.GetLayout());
}

View File

@@ -265,6 +265,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
return 0;
}
throw std::runtime_error("Col2Img: number of dimensions should be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,

View File

@@ -313,6 +313,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
return 0;
}
throw std::runtime_error(
"Conv_bwd_data: number of dimensions must be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,

View File

@@ -265,6 +265,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
return 0;
}
throw std::runtime_error("Conv_bwd: number of dimensions must be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,

View File

@@ -360,6 +360,8 @@ struct ReferenceConvFwd : public device::BaseOperator
return 0;
}
throw std::runtime_error("Conv_fwd: number of dimensions must be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,

View File

@@ -63,12 +63,11 @@ struct ReferenceGemm : public device::BaseOperator
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
AccDataType v_acc = 0;
ComputeTypeA v_a = 0;
ComputeTypeB v_b = 0;
for(int k = 0; k < K; ++k)
{
ComputeTypeA v_a;
ComputeTypeB v_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
@@ -94,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
CDataType v_c;
CDataType v_c = 0;
arg.c_element_op_(v_c, v_acc);

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -10,6 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/numeric.hpp"
namespace ck {
namespace tensor_operation {
@@ -229,6 +230,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
return 0;
}
throw std::runtime_error("Img2Col: number of dimensions should be between 1 and 3.");
return 1;
}
float Run(const device::BaseArgument* p_arg,

View File

@@ -106,9 +106,8 @@ struct DeviceOperationInstanceFactory<
return op_ptrs;
}
};
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -114,9 +114,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmSt
return op_ptrs;
}
};
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif

View File

@@ -0,0 +1,64 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP32
// FP32
// Declaration only; the definition is not part of this header and the declaration is compiled
// only when CK_ENABLE_FP32 is defined. The single parameter is a non-const reference to a vector
// of owning pointers to DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 5, 3>.
// GetInstances() below passes its result vector to this function when all five data types are
// F32. Presumably the definition appends the pre-built FP32 groupnorm instances to that vector
// -- TODO confirm against the instance source file.
void add_device_groupnorm_bwd_gamma_beta_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 5, 3>>>&);
#endif
/// Instance factory specialization for DeviceNormalizationBwdGammaBeta whose last two
/// (integer) template arguments are fixed to 5 and 3 (presumably tensor rank and number of
/// reduced dimensions -- confirm against the device op declaration).
///
/// GetInstances() returns a vector of owning pointers to the matching device operations.
/// The vector is only populated when CK_ENABLE_FP32 is defined and every data type is F32;
/// for any other combination it is returned empty.
template <typename DYDataType,
          typename XDataType,
          typename MeanInvStdDataType,
          typename DGammaDataType,
          typename DBetaDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta<DYDataType,
                                                                  XDataType,
                                                                  MeanInvStdDataType,
                                                                  DGammaDataType,
                                                                  DBetaDataType,
                                                                  5,
                                                                  3>>
{
    using DeviceOp = DeviceNormalizationBwdGammaBeta<DYDataType,
                                                     XDataType,
                                                     MeanInvStdDataType,
                                                     DGammaDataType,
                                                     DBetaDataType,
                                                     5,
                                                     3>;

    /// Collects every registered instance that matches the requested data types.
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> instances;

#ifdef CK_ENABLE_FP32
        // Evaluated at compile time: true only when all five data types are F32.
        constexpr bool all_f32 =
            is_same_v<DYDataType, F32> && is_same_v<XDataType, F32> &&
            is_same_v<MeanInvStdDataType, F32> && is_same_v<DGammaDataType, F32> &&
            is_same_v<DBetaDataType, F32>;

        // The call is discarded for non-matching types, so the element type of
        // `instances` never has to match the F32 signature in that case.
        if constexpr(all_f32)
        {
            add_device_groupnorm_bwd_gamma_beta_f32_instances(instances);
        }
#endif

        return instances;
    }
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,83 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
// FP16
// Declaration only; the definition is not part of this header and the declaration is compiled
// only when CK_ENABLE_FP16 is defined. The parameter is a non-const reference to a vector of
// owning pointers to DeviceNormalizationBwdGammaBeta<F16, F16, F16, F16, F16, 2, 1>.
// GetInstances() below calls it with its result vector when all five data types are F16 and
// Rank == 2 && NumReduceDim == 1. Presumably the definition appends the FP16 instances to that
// vector -- TODO confirm against the instance source file.
void add_device_layernorm2d_bwd_gamma_beta_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F16, F16, F16, F16, F16, 2, 1>>>&);
#endif
#ifdef CK_ENABLE_FP32
// FP32
// Same shape of declaration as the FP16 one above, for
// DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 2, 1>; compiled only when
// CK_ENABLE_FP32 is defined. GetInstances() below calls it when all five data types are F32 and
// Rank == 2 && NumReduceDim == 1.
void add_device_layernorm2d_bwd_gamma_beta_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 2, 1>>>&);
#endif
/// Instance factory specialization for DeviceNormalizationBwdGammaBeta with arbitrary
/// Rank and NumReduceDim.
///
/// GetInstances() returns a vector of owning pointers to the matching device operations.
/// Instances are only added for Rank == 2 and NumReduceDim == 1, and only when all five
/// data types are F16 (with CK_ENABLE_FP16 defined) or all F32 (with CK_ENABLE_FP32
/// defined). Every other combination yields an empty vector.
template <typename DYDataType,
          typename XDataType,
          typename MeanInvStdDataType,
          typename DGammaDataType,
          typename DBetaDataType,
          index_t Rank,
          index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta<DYDataType,
                                                                  XDataType,
                                                                  MeanInvStdDataType,
                                                                  DGammaDataType,
                                                                  DBetaDataType,
                                                                  Rank,
                                                                  NumReduceDim>>
{
    using DeviceOp = DeviceNormalizationBwdGammaBeta<DYDataType,
                                                     XDataType,
                                                     MeanInvStdDataType,
                                                     DGammaDataType,
                                                     DBetaDataType,
                                                     Rank,
                                                     NumReduceDim>;

    /// Collects every registered instance that matches the requested types and shape.
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> instances;

#ifdef CK_ENABLE_FP16
        // Compile-time check: every data type is F16.
        constexpr bool all_f16 =
            is_same_v<DYDataType, F16> && is_same_v<XDataType, F16> &&
            is_same_v<MeanInvStdDataType, F16> && is_same_v<DGammaDataType, F16> &&
            is_same_v<DBetaDataType, F16>;

        // The registration call is discarded unless both the types and the
        // (Rank, NumReduceDim) pair match its declared signature.
        if constexpr(all_f16 && Rank == 2 && NumReduceDim == 1)
        {
            add_device_layernorm2d_bwd_gamma_beta_f16_instances(instances);
        }
#endif

#ifdef CK_ENABLE_FP32
        // Compile-time check: every data type is F32.
        constexpr bool all_f32 =
            is_same_v<DYDataType, F32> && is_same_v<XDataType, F32> &&
            is_same_v<MeanInvStdDataType, F32> && is_same_v<DGammaDataType, F32> &&
            is_same_v<DBetaDataType, F32>;

        if constexpr(all_f32 && Rank == 2 && NumReduceDim == 1)
        {
            add_device_layernorm2d_bwd_gamma_beta_f32_instances(instances);
        }
#endif

        return instances;
    }
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
@@ -57,7 +58,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple<
DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV2< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,

View File

@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
@@ -52,7 +53,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances = std::tuple<
DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>
DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV2< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,

View File

@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
@@ -57,7 +58,8 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple<
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV2< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,

View File

@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
@@ -52,7 +53,8 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple<
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV2< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,

View File

@@ -27,6 +27,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
@@ -110,17 +111,39 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format on
>;
template <ck::tensor_operation::device::GemmSpecialization GemmSpec>
template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::PipelineVersion PipVer,
ck::LoopScheduler LoopSche>
using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>
// clang-format on
>;
@@ -141,9 +164,51 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
add_device_operation_instances(
instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances<GemmMNKPadding>{});
add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<GemmMNKPadding>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
}
} // namespace instance

View File

@@ -27,6 +27,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
@@ -95,6 +96,41 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple<
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>
// clang-format on
>;
// Tuple of DeviceGemmXdlSplitKCShuffle instantiations for F16 A/B/C data with F32 accumulation,
// Row/Col/Row layouts and PassThrough element-wise operations. GemmSpec, PipVer and LoopSche are
// forwarded unchanged to every entry, so one table can be registered for several
// specialization / pipeline / scheduler combinations.
// Going by the column headers below, every entry has MPerBlock or NPerBlock equal to 16
// (presumably the "irregular" shapes the name refers to -- TODO confirm intent with the author).
template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::PipelineVersion PipVer,
ck::LoopScheduler LoopSche>
using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>
// clang-format on
>;
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
@@ -112,6 +148,52 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
add_device_operation_instances(
instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances<GemmMNKPadding>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
}
} // namespace instance

View File

@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f16_instances(
void add_device_layernorm2d_bwd_gamma_beta_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F16, F16, F16, F16, F16, 2, 1>>>&
instances)
{

View File

@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f32_instances(
void add_device_layernorm2d_bwd_gamma_beta_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 2, 1>>>&
instances)
{

View File

@@ -298,7 +298,7 @@ int profile_gemm_impl(int do_verification,
}
}
return pass ? 0 : 1;
return pass;
}
} // namespace profiler

View File

@@ -145,7 +145,7 @@ bool profile_gemm_splitk_impl(int do_verification,
// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
{
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 32, 36, 40, 64, 96, 128};
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 19, 20, 32, 38};
if(KBatch > 0)
{

View File

@@ -0,0 +1,261 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp"
namespace ck {
namespace profiler {
template <typename DYDataType,
typename XDataType,
typename MeanInvStdDataType,
typename ComputeDataType,
typename DGammaDataType,
typename DBetaDataType>
bool profile_groupnorm_bwd_gamma_beta_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> length)
{
// we don't need GammaDataType and DXDataType here, just for reference class
using GammaDataType = DYDataType;
using DXDataType = DYDataType;
if(length.size() != 5)
return false;
index_t N = length[0];
index_t G = length[3];
index_t C = length[4];
std::vector<index_t> reduce_dim = {0, 1, 2};
std::vector<index_t> gamma_beta_length = {G, C};
Tensor<DYDataType> dy(length);
Tensor<XDataType> x(length);
Tensor<GammaDataType> gamma(gamma_beta_length); // dummy tensor, for reference
Tensor<MeanInvStdDataType> mean({N, G});
Tensor<MeanInvStdDataType> inv_std({N, G});
Tensor<DGammaDataType> dgamma(gamma_beta_length);
Tensor<DBetaDataType> dbeta(gamma_beta_length);
Tensor<DXDataType> host_dx(length); // dummy tensor, for reference
Tensor<DGammaDataType> host_dgamma(gamma_beta_length);
Tensor<DBetaDataType> host_dbeta(gamma_beta_length);
std::vector<index_t> strideDy =
std::vector<ck::index_t>{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()};
std::vector<index_t> strideX =
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()};
std::vector<index_t> strideDGamma{dgamma.mDesc.GetStrides().begin(),
dgamma.mDesc.GetStrides().end()};
std::vector<index_t> strideDBeta{dbeta.mDesc.GetStrides().begin(),
dbeta.mDesc.GetStrides().end()};
std::vector<index_t> strideMeanInvStd = {G, 0, 0, 1, 0};
switch(init_method)
{
case 0:
dy.GenerateTensorValue(GeneratorTensor_1<DYDataType>{});
x.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
mean.GenerateTensorValue(GeneratorTensor_1<MeanInvStdDataType>{});
inv_std.GenerateTensorValue(GeneratorTensor_1<MeanInvStdDataType>{});
dgamma.GenerateTensorValue(GeneratorTensor_1<DGammaDataType>{});
dbeta.GenerateTensorValue(GeneratorTensor_1<DBetaDataType>{});
break;
case 1:
dy.GenerateTensorValue(GeneratorTensor_2<DYDataType>{-5, 5});
x.GenerateTensorValue(GeneratorTensor_2<XDataType>{-5, 5});
mean.GenerateTensorValue(GeneratorTensor_2<MeanInvStdDataType>{-5, 5});
inv_std.GenerateTensorValue(GeneratorTensor_2<MeanInvStdDataType>{0, 5});
dgamma.GenerateTensorValue(GeneratorTensor_2<DGammaDataType>{-5, 5});
dbeta.GenerateTensorValue(GeneratorTensor_2<DBetaDataType>{-5, 5});
break;
default:
dy.GenerateTensorValue(GeneratorTensor_3<DYDataType>{0, 1});
x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1});
mean.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{-0.5, 0.5});
inv_std.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{0, 0.5});
dgamma.GenerateTensorValue(GeneratorTensor_3<DGammaDataType>{-0.5, 0.5});
dbeta.GenerateTensorValue(GeneratorTensor_3<DBetaDataType>{-0.5, 0.5});
}
DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());
dy_dev.ToDevice(dy.mData.data());
x_dev.ToDevice(x.mData.data());
mean_dev.ToDevice(mean.mData.data());
inv_std_dev.ToDevice(inv_std.mData.data());
// add device normalization instances
using DeviceOp =
ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta<DYDataType,
XDataType,
MeanInvStdDataType,
DGammaDataType,
DBetaDataType,
5,
3>;
// get device op instances
const auto instance_ptrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << instance_ptrs.size() << " instances" << std::endl;
std::string best_instance_name;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
if(do_verification)
{
using ReferenceInstance =
ck::tensor_operation::host::ReferenceGroupnormBwd<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DGammaDataType,
DBetaDataType,
DXDataType,
ComputeDataType>;
ReferenceInstance ref;
auto ref_argument =
ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, length);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument);
}
std::size_t num_bytes = dy.mDesc.GetElementSize() * sizeof(DYDataType) +
x.mDesc.GetElementSize() * sizeof(XDataType) +
mean.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) +
inv_std.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) +
dgamma.mDesc.GetElementSize() * sizeof(DGammaDataType) +
dbeta.mDesc.GetElementSize() * sizeof(DBetaDataType);
int num_kernel = 0;
for(auto& inst_ptr : instance_ptrs)
{
auto argument_ptr = inst_ptr->MakeArgumentPointer(length,
strideDy,
strideX,
strideMeanInvStd,
strideMeanInvStd,
gamma_beta_length,
strideDGamma,
strideDBeta,
reduce_dim,
dy_dev.GetDeviceBuffer(),
x_dev.GetDeviceBuffer(),
mean_dev.GetDeviceBuffer(),
inv_std_dev.GetDeviceBuffer(),
dgamma_dev.GetDeviceBuffer(),
dbeta_dev.GetDeviceBuffer());
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
}
else
{
if(time_kernel)
{
std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: ";
LogRange(std::cout << "input lengths = ", length, ", ") << std::endl;
}
continue;
}
size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = inst_ptr->MakeInvokerPointer();
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
float gb_per_sec = num_bytes / 1.E6 / avg_time;
if(time_kernel)
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
<< inst_ptr->GetTypeString() << std::endl;
if(avg_time < best_avg_time)
{
best_instance_name = inst_ptr->GetTypeString();
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
dgamma_dev.FromDevice(dgamma.mData.data());
dbeta_dev.FromDevice(dbeta.mData.data());
bool pass =
ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3);
pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3);
if(do_log)
{
LogRangeAsType<float>(std::cout << "dy : ", dy.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "host_dgamma : ", host_dgamma.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "dgamma : ", dgamma.mData, ",") << std::endl;
}
if(!pass)
{
std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
return false;
}
else
{
if(time_kernel)
std::cout << "pass" << std::endl;
}
}
}
if(time_kernel)
{
LogRange(std::cout << "length = ", length, ",") << ", ";
LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl;
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s,"
<< best_instance_name << std::endl;
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
return true;
}
} // namespace profiler
} // namespace ck

View File

@@ -0,0 +1,263 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp"
namespace ck {
namespace profiler {
/// Runs every registered DeviceNormalizationBwdGammaBeta<..., Rank, NumReduceDim> instance on one
/// problem size, optionally verifying dgamma/dbeta against the host reference
/// (ReferenceLayernormBwd) and printing per-instance timings.
///
/// @tparam Rank            Number of tensor dimensions; must equal length.size() and be >= 2.
/// @param do_verification  Non-zero: compute host_dgamma/host_dbeta once with the reference
///                         operator and compare every instance's output against them
///                         (both tolerances passed to check_err are 1e-3).
/// @param init_method      0 fills the tensors with GeneratorTensor_1, 1 with GeneratorTensor_2,
///                         any other value with GeneratorTensor_3 (bounds as written below).
/// @param do_log           Print dy, host_dgamma and dgamma after each verified instance.
/// @param time_kernel      Forwarded to StreamConfig; also enables the per-instance and summary
///                         output.
/// @param length           Tensor extents; dimension 0 is the reduce dimension and the remaining
///                         extents give the shape of gamma / dgamma / dbeta.
/// @return false if length does not match Rank (or Rank < 2), if any instance fails
///         verification, or if no instance supports the argument; true otherwise.
template <typename DYDataType,
          typename XDataType,
          typename MeanInvStdDataType,
          typename ComputeDataType,
          typename DGammaDataType,
          typename DBetaDataType,
          index_t Rank>
bool profile_layernorm_bwd_gamma_beta_impl(int do_verification,
                                           int init_method,
                                           bool do_log,
                                           bool time_kernel,
                                           std::vector<index_t> length)
{
    // we don't need GammaDataType and DXDataType here, just for reference class
    using GammaDataType = DYDataType;
    using DXDataType    = DYDataType;

    if(length.size() != Rank || Rank < 2)
        return false;

    // Assume the first dimension is the one reduced over
    // Layernorm 2D, input = [M, K], reduce on M axis
    // Layernorm 4D, input = [N, H, W, C], reduce on N axis
    // NOTE(review): NumReduceDim is Rank - 1 while reduce_dim holds a single entry; the two only
    // agree for Rank == 2 (the only rank the visible caller uses) -- confirm before adding ranks.
    constexpr int NumReduceDim = Rank - 1;

    std::vector<index_t> reduce_dim = {0};
    std::vector<index_t> invariant_length{length.begin() + 1, length.end()};

    Tensor<DYDataType> dy(length);
    Tensor<XDataType> x(length);
    Tensor<GammaDataType> gamma(invariant_length); // dummy tensor, for reference
    Tensor<MeanInvStdDataType> mean({length[0]});
    Tensor<MeanInvStdDataType> inv_std({length[0]});
    Tensor<DGammaDataType> dgamma(invariant_length);
    Tensor<DBetaDataType> dbeta(invariant_length);

    Tensor<DXDataType> host_dx(length); // dummy tensor, for reference
    Tensor<DGammaDataType> host_dgamma(invariant_length);
    Tensor<DBetaDataType> host_dbeta(invariant_length);

    std::vector<index_t> strideDy =
        std::vector<ck::index_t>{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()};
    std::vector<index_t> strideX = strideDy;
    std::vector<index_t> strideDGamma{dgamma.mDesc.GetStrides().begin(),
                                      dgamma.mDesc.GetStrides().end()};
    std::vector<index_t> strideDBeta{dbeta.mDesc.GetStrides().begin(),
                                     dbeta.mDesc.GetStrides().end()};

    // One stride per input dimension: 1 along dimension 0, 0 everywhere else.
    // Parentheses are required here: `{Rank, 0}` would select the initializer_list constructor
    // and build the two-element vector [Rank, 0] regardless of Rank.
    std::vector<index_t> strideMeanInvStd(Rank, 0);
    strideMeanInvStd[0] = 1;

    switch(init_method)
    {
    case 0:
        dy.GenerateTensorValue(GeneratorTensor_1<DYDataType>{});
        x.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
        mean.GenerateTensorValue(GeneratorTensor_1<MeanInvStdDataType>{});
        inv_std.GenerateTensorValue(GeneratorTensor_1<MeanInvStdDataType>{});
        dgamma.GenerateTensorValue(GeneratorTensor_1<DGammaDataType>{});
        dbeta.GenerateTensorValue(GeneratorTensor_1<DBetaDataType>{});
        break;
    case 1:
        dy.GenerateTensorValue(GeneratorTensor_2<DYDataType>{-5, 5});
        x.GenerateTensorValue(GeneratorTensor_2<XDataType>{-5, 5});
        mean.GenerateTensorValue(GeneratorTensor_2<MeanInvStdDataType>{-5, 5});
        inv_std.GenerateTensorValue(GeneratorTensor_2<MeanInvStdDataType>{0, 5});
        dgamma.GenerateTensorValue(GeneratorTensor_2<DGammaDataType>{-5, 5});
        dbeta.GenerateTensorValue(GeneratorTensor_2<DBetaDataType>{-5, 5});
        break;
    default:
        dy.GenerateTensorValue(GeneratorTensor_3<DYDataType>{0, 1});
        x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1});
        mean.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{-0.5, 0.5});
        inv_std.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{0, 0.5});
        dgamma.GenerateTensorValue(GeneratorTensor_3<DGammaDataType>{-0.5, 0.5});
        dbeta.GenerateTensorValue(GeneratorTensor_3<DBetaDataType>{-0.5, 0.5});
    }

    DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
    DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
    DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
    DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
    DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
    DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());

    // Only the inputs are uploaded; dgamma / dbeta are outputs of the device op.
    dy_dev.ToDevice(dy.mData.data());
    x_dev.ToDevice(x.mData.data());
    mean_dev.ToDevice(mean.mData.data());
    inv_std_dev.ToDevice(inv_std.mData.data());

    // add device normalization instances
    using DeviceOp =
        ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta<DYDataType,
                                                                      XDataType,
                                                                      MeanInvStdDataType,
                                                                      DGammaDataType,
                                                                      DBetaDataType,
                                                                      Rank,
                                                                      NumReduceDim>;

    // get device op instances
    const auto instance_ptrs =
        ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
            DeviceOp>::GetInstances();

    std::cout << "found " << instance_ptrs.size() << " instances" << std::endl;

    std::string best_instance_name;
    float best_avg_time   = std::numeric_limits<float>::max();
    float best_gb_per_sec = 0;

    // The host reference is computed once and reused for every instance.
    if(do_verification)
    {
        using ReferenceInstance =
            ck::tensor_operation::host::ReferenceLayernormBwd<DYDataType,
                                                              XDataType,
                                                              GammaDataType,
                                                              MeanInvStdDataType,
                                                              DGammaDataType,
                                                              DBetaDataType,
                                                              DXDataType,
                                                              ComputeDataType>;

        ReferenceInstance ref;
        auto ref_argument =
            ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, length);
        auto ref_invoker = ref.MakeInvoker();
        ref_invoker.Run(ref_argument);
    }

    // Bytes counted for the bandwidth figure: all four inputs plus both outputs.
    std::size_t num_bytes = dy.mDesc.GetElementSize() * sizeof(DYDataType) +
                            x.mDesc.GetElementSize() * sizeof(XDataType) +
                            mean.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) +
                            inv_std.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) +
                            dgamma.mDesc.GetElementSize() * sizeof(DGammaDataType) +
                            dbeta.mDesc.GetElementSize() * sizeof(DBetaDataType);

    int num_kernel = 0;

    for(auto& inst_ptr : instance_ptrs)
    {
        auto argument_ptr = inst_ptr->MakeArgumentPointer(length,
                                                          strideDy,
                                                          strideX,
                                                          strideMeanInvStd,
                                                          strideMeanInvStd,
                                                          invariant_length,
                                                          strideDGamma,
                                                          strideDBeta,
                                                          reduce_dim,
                                                          dy_dev.GetDeviceBuffer(),
                                                          x_dev.GetDeviceBuffer(),
                                                          mean_dev.GetDeviceBuffer(),
                                                          inv_std_dev.GetDeviceBuffer(),
                                                          dgamma_dev.GetDeviceBuffer(),
                                                          dbeta_dev.GetDeviceBuffer());

        if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            ++num_kernel;
        }
        else
        {
            if(time_kernel)
            {
                std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: ";
                LogRange(std::cout << "input lengths = ", length, ", ") << std::endl;
            }

            continue;
        }

        size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get());
        DeviceMem workspace_dev(workspace_sz);
        inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());

        auto invoker_ptr = inst_ptr->MakeInvokerPointer();

        float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

        float gb_per_sec = num_bytes / 1.E6 / avg_time;

        if(time_kernel)
            std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
                      << inst_ptr->GetTypeString() << std::endl;

        if(avg_time < best_avg_time)
        {
            best_instance_name = inst_ptr->GetTypeString();
            best_avg_time      = avg_time;
            best_gb_per_sec    = gb_per_sec;
        }

        if(do_verification)
        {
            dgamma_dev.FromDevice(dgamma.mData.data());
            dbeta_dev.FromDevice(dbeta.mData.data());

            bool pass =
                ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3);
            pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3);

            if(do_log)
            {
                LogRangeAsType<float>(std::cout << "dy  : ", dy.mData, ",") << std::endl;
                LogRangeAsType<float>(std::cout << "host_dgamma  : ", host_dgamma.mData, ",")
                    << std::endl;
                LogRangeAsType<float>(std::cout << "dgamma  : ", dgamma.mData, ",") << std::endl;
            }

            if(!pass)
            {
                std::cout << inst_ptr->GetTypeString() << " failed verification: ";
                LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
                return false;
            }
            else
            {
                if(time_kernel)
                    std::cout << "pass" << std::endl;
            }
        }
    }

    // Bail out before the summary: with no applicable instance, best_avg_time still holds the
    // numeric_limits<float>::max() sentinel and best_instance_name is empty.
    if(num_kernel == 0)
    {
        std::cout << "Error: No kernel is applicable" << std::endl;
        return false;
    }

    if(time_kernel)
    {
        LogRange(std::cout << "length = ", length, ",") << ", ";
        LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl;
        std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s,"
                  << best_instance_name << std::endl;
    }

    return true;
}
} // namespace profiler
} // namespace ck

View File

@@ -19,6 +19,8 @@ set(PROFILER_SOURCES
profile_groupnorm_bwd_data.cpp
profile_groupnorm_fwd.cpp
profile_layernorm_bwd_data.cpp
profile_layernorm_bwd_gamma_beta.cpp
profile_groupnorm_bwd_gamma_beta.cpp
profile_layernorm_fwd.cpp
profile_max_pool3d_fwd.cpp
profile_avg_pool3d_bwd.cpp
@@ -82,6 +84,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)

View File

@@ -137,8 +137,14 @@ int profile_gemm(int argc, char* argv[])
return pass ? 0 : 1;
};
if(false)
;
if(data_type != GemmDataType::F32_F32_F32 && data_type != GemmDataType::F16_F16_F16 &&
data_type != GemmDataType::BF16_BF16_BF16 && data_type != GemmDataType::INT8_INT8_INT8 &&
data_type != GemmDataType::F8_F8_F8)
{
// dummy clause before the else clauses for different data types
std::cout << "Gemm: this data_type is not implemented" << std::endl;
return 1;
}
#ifdef CK_ENABLE_FP32
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
@@ -231,7 +237,7 @@ int profile_gemm(int argc, char* argv[])
#endif
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
std::cout << "Gemm: this data_type & layout is not implemented" << std::endl;
return 1;
}

View File

@@ -0,0 +1,104 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include <unordered_map>
#include "profiler/data_type_enum.hpp"
#include "profiler/profile_groupnorm_bwd_gamma_beta_impl.hpp"
#include "profiler_operation_registry.hpp"
using ck::index_t;
// Minimal command-line parser for the "--<key> v0 v1 ..." integer-list options of the
// groupnorm_bwd_gamma_beta profiler. The only option known by default is "--length".
struct groupnormBwdGammaBetaArgParser
{
    // Option name (without the leading "--") -> integers collected for it.
    std::unordered_map<std::string, std::vector<int>> long_opts = {{"length", {}}};

    // If argv[i] is exactly "--<key>", appends every following argument up to (not including)
    // the next one that starts with '-' to long_opts[key] and returns true; otherwise false.
    bool parse_opt(int argc, char* argv[], const std::string& key, int i)
    {
        const std::string flag = "--" + key;
        if(flag != argv[i])
        {
            return false;
        }

        // Locate the end of the value run first, then convert, so conversion only starts once
        // the whole run is known.
        int stop = i + 1;
        while(stop < argc && argv[stop][0] != '-')
        {
            ++stop;
        }

        for(int idx = i + 1; idx < stop; ++idx)
        {
            long_opts[key].push_back(std::stoi(argv[idx]));
        }
        return true;
    }

    // Scans argv[1..argc) once per known option and stops at that option's first occurrence.
    void operator()(int argc, char* argv[])
    {
        for(auto& entry : long_opts)
        {
            int i = 1;
            while(i < argc && !parse_opt(argc, argv, entry.first, i))
            {
                ++i;
            }
        }
    }
};
// Prints the positional-argument and --length usage text for the groupnorm_bwd_gamma_beta
// profiler to std::cout, followed by a blank line and a flush.
void print_help_groupnorm_bwd_gamma_beta()
{
    // eg: ckProfiler groupnorm_bwd_gamma_beta 1 0 2 0 1 --length 1 16 16 32 40
    static const char* const usage_lines[] = {
        "arg1: data type (0: fp16; 1: fp32)\n",
        "arg2: verification (0: no; 1: yes)\n",
        "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n",
        "arg4: print tensor value (0: no; 1: yes)\n",
        "arg5: time kernel (0=no, 1=yes)\n",
        "--length: tensor extents (e.g, --length 1 16 16 32 40) \n",
    };

    for(const char* line : usage_lines)
    {
        std::cout << line;
    }
    std::cout << std::endl;
}
// Entry point of the "groupnorm_bwd_gamma_beta" profiler operation.
//
// Expected command line: <exe> <op-name> arg1 arg2 arg3 arg4 arg5 --length d0 d1 d2 d3 d4
// (see print_help_groupnorm_bwd_gamma_beta for the meaning of arg1..arg5).
//
// Returns 0 after printing the help text when the positional arguments are missing, 0 when the
// profiler reports success and 1 when it reports failure. Throws std::runtime_error for an
// unsupported data type or when --length does not carry exactly five values.
int profile_groupnorm_bwd_gamma_beta(int argc, char* argv[])
{
    // argv[2]..argv[6] are read unconditionally below, so all of them must be present.
    if(argc < 7)
    {
        print_help_groupnorm_bwd_gamma_beta();
        return 0;
    }

    groupnormBwdGammaBetaArgParser arg_parser;

    // short unnamed options
    const ck::DataTypeEnum data_type = static_cast<ck::DataTypeEnum>(std::stoi(argv[2]));
    const bool do_verification       = std::stoi(argv[3]);
    const int init_method            = std::stoi(argv[4]);
    const bool do_log                = std::stoi(argv[5]);
    const bool time_kernel           = std::stoi(argv[6]);

    // parse the long options
    arg_parser(argc, argv);
    const std::vector<index_t> length = arg_parser.long_opts["length"];

    using F32 = float;

    bool pass = false;

    if(length.size() == 5)
    {
        if(data_type == ck::DataTypeEnum::Float)
        {
            pass = ck::profiler::profile_groupnorm_bwd_gamma_beta_impl<F32, F32, F32, F32, F32, F32>(
                do_verification, init_method, do_log, time_kernel, length);
        }
        else
        {
            throw std::runtime_error("not implemented yet");
        }
    }
    else
    {
        throw std::runtime_error("length should be 5");
    }

    // Propagate the profiler verdict so a failed verification yields a non-zero exit status.
    return pass ? 0 : 1;
}
// Hands profile_groupnorm_bwd_gamma_beta to REGISTER_PROFILER_OPERATION under the operation name
// "groupnorm_bwd_gamma_beta" with the description "Group Normalization".
// NOTE(review): the macro is presumably provided by profiler_operation_registry.hpp and makes
// the operation selectable from the profiler executable -- confirm against that header.
REGISTER_PROFILER_OPERATION("groupnorm_bwd_gamma_beta",
"Group Normalization",
profile_groupnorm_bwd_gamma_beta);

View File

@@ -0,0 +1,112 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include <unordered_map>
#include "profiler/data_type_enum.hpp"
#include "profiler/profile_layernorm_bwd_gamma_beta_impl.hpp"
#include "profiler_operation_registry.hpp"
using ck::index_t;
// Minimal command-line parser for the "--<key> v0 v1 ..." integer-list options of the
// layernorm_bwd_gamma_beta profiler. The only option known by default is "--length".
struct layernormBwdGammaBetaArgParser
{
    // Option name (without the leading "--") -> integers collected for it.
    std::unordered_map<std::string, std::vector<int>> long_opts = {{"length", {}}};

    // If argv[i] is exactly "--<key>", appends every following argument up to (not including)
    // the next one that starts with '-' to long_opts[key] and returns true; otherwise false.
    bool parse_opt(int argc, char* argv[], const std::string& key, int i)
    {
        const std::string flag = "--" + key;
        if(flag != argv[i])
        {
            return false;
        }

        // Locate the end of the value run first, then convert, so conversion only starts once
        // the whole run is known.
        int stop = i + 1;
        while(stop < argc && argv[stop][0] != '-')
        {
            ++stop;
        }

        for(int idx = i + 1; idx < stop; ++idx)
        {
            long_opts[key].push_back(std::stoi(argv[idx]));
        }
        return true;
    }

    // Scans argv[1..argc) once per known option and stops at that option's first occurrence.
    void operator()(int argc, char* argv[])
    {
        for(auto& entry : long_opts)
        {
            int i = 1;
            while(i < argc && !parse_opt(argc, argv, entry.first, i))
            {
                ++i;
            }
        }
    }
};
// Prints the positional-argument and --length usage text for the layernorm_bwd_gamma_beta
// profiler to std::cout, followed by a blank line and a flush.
void print_help_layernorm_bwd_gamma_beta()
{
    // eg: ckProfiler layernorm_bwd_gamma_beta 0 0 2 0 1 --length 1502 4096
    static const char* const usage_lines[] = {
        "arg1: data type (0: fp16; 1: fp32)\n",
        "arg2: verification (0: no; 1: yes)\n",
        "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n",
        "arg4: print tensor value (0: no; 1: yes)\n",
        "arg5: time kernel (0=no, 1=yes)\n",
        "--length: tensor extents (e.g, --length 1024 1024) \n",
    };

    for(const char* line : usage_lines)
    {
        std::cout << line;
    }
    std::cout << std::endl;
}
// Entry point of the "layernorm_bwd_gamma_beta" profiler operation.
//
// Expected command line: <exe> <op-name> arg1 arg2 arg3 arg4 arg5 --length d0 d1
// (see print_help_layernorm_bwd_gamma_beta for the meaning of arg1..arg5).
//
// Returns 0 after printing the help text when the positional arguments are missing, 0 when the
// profiler reports success and 1 when it reports failure. Throws std::runtime_error for an
// unsupported data type or when --length does not carry exactly two values.
int profile_layernorm_bwd_gamma_beta(int argc, char* argv[])
{
    // argv[2]..argv[6] are read unconditionally below, so all of them must be present.
    if(argc < 7)
    {
        print_help_layernorm_bwd_gamma_beta();
        return 0;
    }

    layernormBwdGammaBetaArgParser arg_parser;

    // short unnamed options
    const ck::DataTypeEnum data_type = static_cast<ck::DataTypeEnum>(std::stoi(argv[2]));
    const bool do_verification       = std::stoi(argv[3]);
    const int init_method            = std::stoi(argv[4]);
    const bool do_log                = std::stoi(argv[5]);
    const bool time_kernel           = std::stoi(argv[6]);

    // parse the long options
    arg_parser(argc, argv);
    const std::vector<index_t> length = arg_parser.long_opts["length"];

    using F16 = ck::half_t;
    using F32 = float;

    bool pass = false;

    if(length.size() == 2)
    {
        constexpr int rank = 2;

        if(data_type == ck::DataTypeEnum::Half)
        {
            pass = ck::profiler::
                profile_layernorm_bwd_gamma_beta_impl<F16, F16, F16, F32, F16, F16, rank>(
                    do_verification, init_method, do_log, time_kernel, length);
        }
        else if(data_type == ck::DataTypeEnum::Float)
        {
            pass = ck::profiler::
                profile_layernorm_bwd_gamma_beta_impl<F32, F32, F32, F32, F32, F32, rank>(
                    do_verification, init_method, do_log, time_kernel, length);
        }
        else
        {
            throw std::runtime_error("not implemented yet");
        }
    }
    else
    {
        throw std::runtime_error("not implemented yet");
    }

    // Propagate the profiler verdict so a failed verification yields a non-zero exit status.
    return pass ? 0 : 1;
}
// Hooks profile_layernorm_bwd_gamma_beta into the profiler under the operation name
// "layernorm_bwd_gamma_beta". NOTE(review): the second argument looks like a
// human-readable description — confirm against profiler_operation_registry.hpp.
REGISTER_PROFILER_OPERATION("layernorm_bwd_gamma_beta",
"Layer Normalization",
profile_layernorm_bwd_gamma_beta);

View File

@@ -1,2 +1,2 @@
#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}'
find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}'
# Run clang-format on files git reports as changed (skipping deletions) whose path contains
# a .cpp, .hpp or .inc extension. The alternation is grouped so the leading dot applies to
# every extension; ungrouped, a bare "hpp" or "inc" anywhere in the path (e.g. "include/")
# would match.
git status --porcelain | awk '$1 != "D" && (match($2, "\\.(cpp|hpp|inc)")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}'

View File

@@ -140,6 +140,7 @@ add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax)
add_subdirectory(normalization_fwd)
add_subdirectory(normalization_bwd_data)
add_subdirectory(normalization_bwd_gamma_beta)
add_subdirectory(data_type)
add_subdirectory(elementwise_normalization)
add_subdirectory(batchnorm)

View File

@@ -135,6 +135,8 @@ class TestConvTensorRearrangeInterface : public ::testing::Test
return col2img.IsSupportedArgument(argument);
}
throw std::runtime_error("Conv_tensor_rearrange: problem with tensor rearrange operator. ");
return 1;
}
};

View File

@@ -0,0 +1,13 @@
# Umbrella target that the individual normalization backward gamma/beta tests below are
# attached to as dependencies.
add_custom_target(test_normalization_bwd_gamma_beta)
# FP32 2D layernorm backward gamma/beta test executable.
add_gtest_executable(test_layernorm2d_bwd_gamma_beta_fp32 test_layernorm2d_bwd_gamma_beta_fp32.cpp)
# NOTE(review): `result` is presumably set by add_gtest_executable (0 when the executable
# target was actually created) — confirm against its definition. Only in that case is the
# target linked and added to the umbrella target.
if(result EQUAL 0)
target_link_libraries(test_layernorm2d_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance)
add_dependencies(test_normalization_bwd_gamma_beta test_layernorm2d_bwd_gamma_beta_fp32)
endif()
# FP32 groupnorm backward gamma/beta test executable, wired up the same way.
add_gtest_executable(test_groupnorm_bwd_gamma_beta_fp32 test_groupnorm_bwd_gamma_beta_fp32.cpp)
if(result EQUAL 0)
target_link_libraries(test_groupnorm_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance)
add_dependencies(test_normalization_bwd_gamma_beta test_groupnorm_bwd_gamma_beta_fp32)
endif()

View File

@@ -0,0 +1,51 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "profiler/profile_groupnorm_bwd_gamma_beta_impl.hpp"
using F16 = ck::half_t;
using F32 = float;
using ck::index_t;
// Typed fixture that drives ck::profiler::profile_groupnorm_bwd_gamma_beta_impl over a
// fixed list of problem sizes. `Tuple` supplies the six element types, in the order of
// the aliases below.
template <typename Tuple>
class TestgroupnormBwdGammaBeta : public ::testing::Test
{
    protected:
    using DYDataType         = std::tuple_element_t<0, Tuple>;
    using XDataType          = std::tuple_element_t<1, Tuple>;
    using MeanInvStdDataType = std::tuple_element_t<2, Tuple>;
    using ComputeDataType    = std::tuple_element_t<3, Tuple>;
    using DGammaDataType     = std::tuple_element_t<4, Tuple>;
    using DBetaDataType      = std::tuple_element_t<5, Tuple>;

    // Runs the profiler once per problem size and expects each run to report success.
    void Run()
    {
        // Five extents per case; the original author's note gives the layout as
        // [N, H, W, G, C] with H, W, C reduced.
        // NOTE(review): that note is labelled "Bwd data" — confirm the reduced
        // dimensions for the gamma/beta kernel.
        const std::vector<std::vector<ck::index_t>> problem_sizes = {{1, 1, 1, 1, 1},
                                                                     {1, 2, 3, 4, 5},
                                                                     {256, 9, 9, 9, 9},
                                                                     {1, 64, 64, 32, 10},
                                                                     {1, 32, 32, 32, 20},
                                                                     {1, 16, 16, 32, 40}};

        // Each call gets its own mutable copy of the extents.
        // NOTE(review): the flag arguments are presumably (verify, init method, log,
        // time kernel) as in the layernorm profiler — confirm against the impl header.
        const auto run_case = [](std::vector<ck::index_t> dims) {
            return ck::profiler::profile_groupnorm_bwd_gamma_beta_impl<DYDataType,
                                                                       XDataType,
                                                                       MeanInvStdDataType,
                                                                       ComputeDataType,
                                                                       DGammaDataType,
                                                                       DBetaDataType>(
                true, 2, false, false, dims);
        };

        for(const auto& dims : problem_sizes)
        {
            bool success = run_case(dims);
            EXPECT_TRUE(success);
        }
    }
};
// Type combinations the typed suite is instantiated with; each tuple feeds the fixture's
// std::tuple_element_t aliases in the column order listed below.
using KernelTypes = ::testing::Types<
// DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType
std::tuple<F32, F32, F32, F32, F32, F32>>;
// Bind the fixture to the type list and invoke its Run() for every tuple.
TYPED_TEST_SUITE(TestgroupnormBwdGammaBeta, KernelTypes);
TYPED_TEST(TestgroupnormBwdGammaBeta, Test_FP32) { this->Run(); }

View File

@@ -0,0 +1,48 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "profiler/profile_layernorm_bwd_gamma_beta_impl.hpp"
using F16 = ck::half_t;
using F32 = float;
using ck::index_t;
// Typed fixture that drives ck::profiler::profile_layernorm_bwd_gamma_beta_impl (rank 2)
// over a fixed list of problem sizes. `Tuple` supplies the six element types, in the
// order of the aliases below.
template <typename Tuple>
class TestLayernorm2dBwdGammaBeta : public ::testing::Test
{
    protected:
    using DYDataType         = std::tuple_element_t<0, Tuple>;
    using XDataType          = std::tuple_element_t<1, Tuple>;
    using MeanInvStdDataType = std::tuple_element_t<2, Tuple>;
    using ComputeDataType    = std::tuple_element_t<3, Tuple>;
    using DGammaDataType     = std::tuple_element_t<4, Tuple>;
    using DBetaDataType      = std::tuple_element_t<5, Tuple>;

    // Runs the profiler once per problem size and expects each run to report success.
    void Run()
    {
        // Two extents per case; the original author's note gives the layout as [N, D]
        // with D reduced.
        // NOTE(review): that note is labelled "Bwd data" — confirm the reduced
        // dimension for the gamma/beta kernel.
        const std::vector<std::vector<ck::index_t>> problem_sizes = {
            {4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}};

        // Each call gets its own mutable copy of the extents. The flag arguments follow
        // the (verification, init method, log, time kernel) order used by the profiler
        // entry point for this impl.
        const auto run_case = [](std::vector<ck::index_t> dims) {
            return ck::profiler::profile_layernorm_bwd_gamma_beta_impl<DYDataType,
                                                                       XDataType,
                                                                       MeanInvStdDataType,
                                                                       ComputeDataType,
                                                                       DGammaDataType,
                                                                       DBetaDataType,
                                                                       2>(
                true, 2, false, false, dims);
        };

        for(const auto& dims : problem_sizes)
        {
            bool success = run_case(dims);
            EXPECT_TRUE(success);
        }
    }
};
// Type combinations the typed suite is instantiated with; each tuple feeds the fixture's
// std::tuple_element_t aliases in the column order listed below.
using KernelTypes = ::testing::Types<
// DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType
std::tuple<F32, F32, F32, F32, F32, F32>>;
// Bind the fixture to the type list and invoke its Run() for every tuple.
TYPED_TEST_SUITE(TestLayernorm2dBwdGammaBeta, KernelTypes);
TYPED_TEST(TestLayernorm2dBwdGammaBeta, Test_FP32) { this->Run(); }

View File

@@ -21,49 +21,59 @@ template <typename InputTensor,
typename OutputTensor,
typename BlockShape,
typename ThreadLayoutShape,
typename LocalTileSteps,
typename LocalPartitionSteps>
bool UseOptimizedCopy>
__global__ void TestCopyDevice(const InputTensor input_tensor,
OutputTensor output_tensor,
const BlockShape tile_shape,
const ThreadLayoutShape thread_layout,
const LocalTileSteps block_steps,
const LocalPartitionSteps thread_steps)
const ThreadLayoutShape thread_layout)
{
__shared__ ck::index_t p_shared[ck::wrapper::size(tile_shape)];
auto tensor_lds = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
const auto tensor_lds = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
p_shared, ck::wrapper::make_layout(tile_shape));
const auto block_idxs = ck::make_tuple(ck::make_tuple(0, 0), blockIdx.x);
const auto block_idx = static_cast<ck::index_t>(blockIdx.x);
// Get local tiles for global memory
const auto input_local_tile =
ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs, block_steps);
const auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx);
const auto output_local_tile =
ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs, block_steps);
ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx);
// Get partition per thread
const auto input_local_partition = ck::wrapper::make_local_partition(
input_local_tile, thread_layout, threadIdx.x, thread_steps);
const auto input_local_partition =
ck::wrapper::make_local_partition(input_local_tile, thread_layout, threadIdx.x);
auto lds_local_partition =
ck::wrapper::make_local_partition(tensor_lds, thread_layout, threadIdx.x, thread_steps);
auto output_local_partition = ck::wrapper::make_local_partition(
output_local_tile, thread_layout, threadIdx.x, thread_steps);
ck::wrapper::make_local_partition(tensor_lds, thread_layout, threadIdx.x);
auto output_local_partition =
ck::wrapper::make_local_partition(output_local_tile, thread_layout, threadIdx.x);
// Allocate VGPR
constexpr ck::index_t scalar_per_vector = 1;
constexpr ck::index_t vgpr_size = ck::wrapper::size(lds_local_partition);
auto tensor_vgpr = ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr,
vgpr_size,
scalar_per_vector,
ck::index_t>();
auto tensor_vgpr =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, ck::index_t>(
layout(lds_local_partition));
// Perform copy
ck::wrapper::copy(input_local_partition, lds_local_partition);
ck::wrapper::copy(lds_local_partition, tensor_vgpr);
ck::wrapper::copy(tensor_vgpr, output_local_partition);
if constexpr(UseOptimizedCopy)
{
using DimAccessOrder = ck::Tuple<ck::Number<1>, ck::Number<0>>;
constexpr ck::index_t vector_dim = 0;
constexpr ck::index_t scalar_per_vector = 2;
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(input_local_partition,
lds_local_partition);
// TODO: Enable optimized copy for static buffers
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(lds_local_partition,
tensor_vgpr);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(tensor_vgpr,
output_local_partition);
}
else
{
ck::wrapper::copy(input_local_partition, lds_local_partition);
ck::wrapper::copy(lds_local_partition, tensor_vgpr);
ck::wrapper::copy(tensor_vgpr, output_local_partition);
}
}
template <bool UseOptimizedCopy>
void PerformCopyGlobalToGlobalViaLDS()
{
const auto shape =
@@ -89,15 +99,8 @@ void PerformCopyGlobalToGlobalViaLDS()
auto output_tensor_global = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<ck::index_t*>(out_buf.GetDeviceBuffer()), layout);
const auto thread_layout =
ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}), ck::Number<32>{});
const auto tile_shape =
ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<64>{});
const auto thread_steps =
ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}), ck::Number<2>{});
const auto block_steps =
ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}), ck::Number<64>{});
const auto thread_layout = ck::make_tuple(ck::Number<1>{}, ck::Number<32>{});
const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{});
const ck::index_t grid_size = ck::math::integer_divide_ceil(
ck::wrapper::size(input_tensor_global), ck::wrapper::size(tile_shape));
@@ -106,8 +109,7 @@ void PerformCopyGlobalToGlobalViaLDS()
decltype(output_tensor_global),
decltype(tile_shape),
decltype(thread_layout),
decltype(block_steps),
decltype(thread_steps)>;
UseOptimizedCopy>;
launch_and_time_kernel(StreamConfig{},
kernel,
dim3(grid_size),
@@ -116,9 +118,7 @@ void PerformCopyGlobalToGlobalViaLDS()
input_tensor_global,
output_tensor_global,
tile_shape,
thread_layout,
block_steps,
thread_steps);
thread_layout);
// Verify results
std::vector<ck::index_t> output_data(ck::wrapper::size(shape));
@@ -126,4 +126,5 @@ void PerformCopyGlobalToGlobalViaLDS()
EXPECT_TRUE(ck::utils::check_err(output_data, input_data));
}
TEST(TestCopy, CopyGlobalToGlobalViaLDS) { PerformCopyGlobalToGlobalViaLDS(); }
TEST(TestCopyGlobalToGlobalViaLDS, GenericCopy) { PerformCopyGlobalToGlobalViaLDS<false>(); }
TEST(TestCopyGlobalToGlobalViaLDS, OptimizedCopy) { PerformCopyGlobalToGlobalViaLDS<true>(); }

View File

@@ -29,42 +29,29 @@ TEST(TestPartition, LocalPartition)
const auto tensor =
ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(data.data(), layout);
const auto thread_steps =
ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<1>{}), ck::Number<1>{});
const auto thread_layout =
ck::make_tuple(ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}), ck::Number<1>{});
for(ck::index_t thread_id = 0; thread_id < ck::wrapper::size(thread_layout); thread_id++)
{
const auto raked_partition =
ck::wrapper::make_local_partition(tensor, thread_layout, thread_id);
const auto expected_partition_size =
ck::wrapper::size(tensor) / ck::wrapper::size(thread_layout);
EXPECT_EQ(ck::wrapper::size(raked_partition), expected_partition_size);
EXPECT_EQ(raked_partition(0), thread_id);
}
const auto thread_steps = ck::make_tuple(ck::Number<8>{}, ck::Number<1>{});
const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<1>{});
for(ck::index_t thread_id = 0; thread_id < ck::wrapper::size(thread_layout); thread_id++)
{
const auto packed_partition =
ck::wrapper::make_local_partition(tensor, thread_layout, thread_id, thread_steps);
ck::wrapper::make_local_partition(tensor, thread_layout, thread_id);
const auto expected_partition_size =
ck::wrapper::size(tensor) / ck::wrapper::size(thread_layout);
const auto expected_partition_first_val = thread_id * ck::wrapper::size<0, 0>(thread_steps);
const auto expected_partition_first_val = thread_id * ck::wrapper::size<0>(thread_steps);
const auto expected_partition_second_val = expected_partition_first_val + 1;
EXPECT_EQ(ck::wrapper::size(packed_partition), expected_partition_size);
EXPECT_EQ(packed_partition(0), expected_partition_first_val);
EXPECT_EQ(packed_partition(1), expected_partition_second_val);
}
}
TEST(TestPartition, LocalTile)
{
const auto shape =
ck::make_tuple(ck::make_tuple(ck::Number<16>{}, ck::Number<4>{}), ck::Number<4>{});
const auto strides =
ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<16>{}), ck::Number<64>{});
const auto layout = ck::wrapper::make_layout(shape, strides);
const auto shape = ck::make_tuple(ck::Number<16>{}, ck::Number<4>{}, ck::Number<4>{});
const auto strides = ck::make_tuple(ck::Number<1>{}, ck::Number<16>{}, ck::Number<64>{});
const auto layout = ck::wrapper::make_layout(shape, strides);
std::vector<ck::index_t> data(ck::wrapper::size(layout));
std::iota(data.begin(), data.end(), 0);
@@ -72,48 +59,34 @@ TEST(TestPartition, LocalTile)
const auto tensor =
ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(data.data(), layout);
const auto block_steps =
ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}), ck::Number<2>{});
const auto block_shape =
ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}), ck::Number<2>{});
const auto block_layout =
ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}), ck::Number<2>{});
const auto block_shape = ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{});
const auto num_blocks =
ck::make_tuple(ck::wrapper::size<0>(shape) / ck::wrapper::size<0>(block_shape),
ck::wrapper::size<1>(shape) / ck::wrapper::size<1>(block_shape),
ck::wrapper::size<2>(shape) / ck::wrapper::size<2>(block_shape));
std::vector<ck::index_t> block_idxs(ck::wrapper::size(num_blocks));
std::iota(block_idxs.begin(), block_idxs.end(), 0);
std::vector<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::index_t>> block_idxs;
for(ck::index_t x = 0; x < ck::wrapper::size<0, 0>(block_layout); x++)
for(auto block_idx : block_idxs)
{
for(ck::index_t y = 0; y < ck::wrapper::size<0, 1>(block_layout); y++)
{
for(ck::index_t z = 0; z < ck::wrapper::size<1>(block_layout); z++)
{
block_idxs.emplace_back(ck::make_tuple(x, y), z);
}
}
}
for(const auto& block_idx : block_idxs)
{
const auto raked_tile = ck::wrapper::make_local_tile(tensor, block_shape, block_idx);
const auto packed_tile = ck::wrapper::make_local_tile(tensor, block_shape, block_idx);
const auto expected_tile_size = ck::wrapper::size(block_shape);
EXPECT_EQ(ck::wrapper::size(raked_tile), expected_tile_size);
EXPECT_EQ(raked_tile(0), layout(block_idx));
}
auto expected_tile_first_val = (block_idx % ck::wrapper::size<2>(num_blocks)) *
ck::wrapper::size<2>(block_shape) *
ck::wrapper::size<2>(strides);
block_idx /= ck::wrapper::size<2>(num_blocks);
expected_tile_first_val += (block_idx % ck::wrapper::size<1>(num_blocks)) *
ck::wrapper::size<1>(block_shape) *
ck::wrapper::size<1>(strides);
block_idx /= ck::wrapper::size<1>(num_blocks);
expected_tile_first_val += (block_idx % ck::wrapper::size<0>(num_blocks)) *
ck::wrapper::size<0>(block_shape) *
ck::wrapper::size<0>(strides);
for(const auto& block_idx : block_idxs)
{
const auto packed_tile =
ck::wrapper::make_local_tile(tensor, block_shape, block_idx, block_steps);
const auto expected_tile_size = ck::wrapper::size(block_shape);
const auto expected_tile_first_val =
ck::wrapper::size<0, 0>(block_idx) * ck::wrapper::size<0, 0>(block_shape) *
ck::wrapper::size<0, 0>(strides) +
ck::wrapper::size<0, 1>(block_idx) * ck::wrapper::size<0, 1>(block_shape) *
ck::wrapper::size<0, 1>(strides) +
ck::wrapper::size<1>(block_idx) * ck::wrapper::size<1>(block_shape) *
ck::wrapper::size<1>(strides);
const auto expected_tile_second_val = expected_tile_first_val + 1;
EXPECT_EQ(ck::wrapper::size(packed_tile), expected_tile_size);
EXPECT_EQ(packed_tile(0), expected_tile_first_val);
EXPECT_EQ(packed_tile(1), expected_tile_second_val);
}
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
@@ -100,31 +100,26 @@ TEST(TestTensor, ReadWriteHostMemory)
__global__ void TestTensorReadWriteDevice(void* data, void* success)
{
constexpr ck::index_t nelems = 8;
constexpr ck::index_t scalar_per_vector = 1;
constexpr ck::index_t nelems = 8;
__shared__ ck::index_t p_shared[nelems];
ck::index_t* casted_data_ptr = static_cast<ck::index_t*>(data);
bool* casted_success_ptr = static_cast<bool*>(success);
const auto layout = ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(2, 2), 2));
constexpr auto vgpr_layout =
ck::wrapper::make_layout(make_tuple(ck::Number<nelems>{}), make_tuple(ck::Number<1>{}));
auto tensor_global =
ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(casted_data_ptr, layout);
auto tensor_lds = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(p_shared, layout);
auto tensor_vgpr = ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr,
nelems,
scalar_per_vector,
ck::index_t>();
auto tensor_sgpr = ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Sgpr,
nelems,
scalar_per_vector,
ck::index_t>();
auto tensor_lds = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(p_shared, layout);
auto tensor_vgpr =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, ck::index_t>(
vgpr_layout);
InitTensor(tensor_global);
InitTensor(tensor_lds);
StaticInitTensor<nelems>(tensor_vgpr);
StaticInitTensor<nelems>(tensor_sgpr);
*casted_success_ptr = TestTensorCheck1d(tensor_global);
*casted_success_ptr &= TestTensorCheck3d(tensor_global);
@@ -133,8 +128,6 @@ __global__ void TestTensorReadWriteDevice(void* data, void* success)
*casted_success_ptr &= TestTensorCheck3d(tensor_lds);
*casted_success_ptr &= StaticTestTensorCheck1d<nelems>(tensor_vgpr);
*casted_success_ptr &= StaticTestTensorCheck1d<nelems>(tensor_sgpr);
}
TEST(TestTensor, ReadWriteGlobalLdsRegistersMemory)