mirror of https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00

Merge branch 'develop' into mi300

.gitignore (vendored): 3 changes
@@ -48,3 +48,6 @@ build*
 .gdb_history
 install.dir*
 
+# directories containing generated documentation
+docs/source/_build/
+docs/docBin/
CHANGELOG.md (new file): 24 lines
@@ -0,0 +1,24 @@
# Change Log for Composable Kernel

Full documentation for Composable Kernel is not yet available.

## CK 0.1.1 for ROCm 5.5.0

### Fixed
- Fixed a bug in 6-dimensional kernels (#555).
- Fixed grouped ConvBwdWeight test case failure (#524).

### Optimizations
- Improved performance of the normalization kernel.

### Added
- Added user tutorial (#563).
- Added more instances for irregular GEMM sizes (#560).
- Added inter-wave consumer-producer programming model for GEMM kernels (#310).
- Added multi-D GEMM client APIs (#534).
- Added multi-embeddings support (#542).
- Added Navi3x blockwise GEMM and real GEMM support (#541).
- Added Navi grouped ConvBwdWeight support (#505).

### Changed
- Changed ...
@@ -60,7 +60,7 @@ RUN dpkg -i dumb-init_*.deb && rm dumb-init_*.deb
 ARG PREFIX=/opt/rocm
 # Install packages for processing the performance results
 RUN pip3 install --upgrade pip
-RUN pip3 install sqlalchemy
+RUN pip3 install sqlalchemy==1.4.46
 RUN pip3 install pymysql
 RUN pip3 install pandas
 RUN pip3 install setuptools-rust
Jenkinsfile (vendored): 38 changes
@@ -19,7 +19,14 @@ def runShell(String command){
 }
 
 def getDockerImageName(){
-    def img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}"
+    def img
+    if (params.COMPILER_COMMIT == ""){
+        img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}"
+    }
+    else{
+        def commit = "${params.COMPILER_COMMIT}"[0..6]
+        img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}"
+    }
     return img
 }
@@ -464,6 +471,12 @@ def Build_CK(Map conf=[:]){
                 //we only need the ckProfiler to run the performance tests, so we pack and stash it
                 sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler'
                 stash "ckProfiler.tar.gz"
+                if (params.RUN_FULL_QA){
+                    // build deb packages
+                    sh 'make -j package'
+                    archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb'
+                    archiveArtifacts artifacts: 'composablekernel-tests_*.deb'
+                }
             }
         }
     }
@@ -550,8 +563,9 @@ def process_results(Map conf=[:]){
 }
 
 //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
-CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;COMPILER_VERSION=release
-0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open''' : ""
+CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true
+0 21 * * * % RUN_FULL_QA=false;COMPILER_VERSION=release;COMPILER_COMMIT=
+0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : ""
 
 pipeline {
     agent none
@@ -568,16 +582,16 @@ pipeline {
         description: "Force building docker image (default: false), set to true if docker image needs to be updated.")
         string(
             name: 'ROCMVERSION',
-            defaultValue: '5.3',
-            description: 'Specify which ROCM version to use: 5.2.3, or 5.3 (default), etc.')
+            defaultValue: '5.4.3',
+            description: 'Specify which ROCM version to use: 5.4.3 (default).')
         string(
             name: 'COMPILER_VERSION',
-            defaultValue: 'release',
-            description: 'Specify which version of compiler to use: ck-9110, release (default), or amd-stg-open.')
+            defaultValue: 'amd-stg-open',
+            description: 'Specify which version of compiler to use: ck-9110, release, or amd-stg-open (default).')
         string(
             name: 'COMPILER_COMMIT',
-            defaultValue: '',
-            description: 'Specify which commit of compiler branch to use: leave empty to use the latest commit (default), or use 8a82e4eb7ba28521ba9a9424a0315a8a16590424 commit of amd-stg-open branch.')
+            defaultValue: '5541927df00eabd6a110180170eca7785d436ee3',
+            description: 'Specify which commit of compiler branch to use: leave empty to use the latest commit, or use 5541927df00eabd6a110180170eca7785d436ee3 (default) commit of amd-stg-open branch.')
         string(
             name: 'BUILD_COMPILER',
             defaultValue: 'hipcc',
@@ -643,8 +657,8 @@ pipeline {
         {
             agent{ label rocmnode("gfx908 || gfx90a") }
             environment{
-                setup_args = "${params.COMPILER_VERSION == "ck-9110" ? """ -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS="-O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" """ : """ -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS="-O3 " """ }"
-                execute_args = "${params.COMPILER_VERSION == "ck-9110" ? """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS="-O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ : """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908,gfx90a" -DCMAKE_CXX_FLAGS="-O3" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ }"
+                setup_args = "${params.COMPILER_VERSION == "ck-9110" ? """ -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx1030" -DCMAKE_CXX_FLAGS="-O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" """ : """ -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx1030" -DCMAKE_CXX_FLAGS="-O3 " """ }"
+                execute_args = "${params.COMPILER_VERSION == "ck-9110" ? """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx1030" -DCMAKE_CXX_FLAGS="-O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ : """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908,gfx90a;gfx1030" -DCMAKE_CXX_FLAGS="-O3" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ }"
             }
             steps{
                 Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
@@ -666,7 +680,7 @@ pipeline {
             options { retry(2) }
             agent{ label rocmnode("gfx908 || gfx90a")}
             environment{
-                setup_args = "${params.COMPILER_VERSION == "ck-9110" ? """ -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" -DBUILD_DEV=On """ : """ -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " -DBUILD_DEV=On """}"
+                setup_args = "${params.COMPILER_VERSION == "ck-9110" ? """ -DGPU_TARGETS="gfx908;gfx90a;gfx1030" -DCMAKE_CXX_FLAGS=" -O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" -DBUILD_DEV=On """ : """ -DGPU_TARGETS="gfx908;gfx90a;gfx1030" -DCMAKE_CXX_FLAGS=" -O3 " -DBUILD_DEV=On """}"
             }
             steps{
                 runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
@@ -83,7 +83,7 @@ int main(int argc, char* argv[])
         [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
             using Layout = decltype(layout);
 
-            if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
+            if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
             {
                 return (nRow - 1) * stride + nCol;
             }

@@ -92,7 +92,7 @@ int main(int argc, char* argv[])
         [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
             using Layout = decltype(layout);
 
-            if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
+            if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
             {
                 return (nRow - 1) * stride + nCol;
             }

@@ -88,7 +88,7 @@ int main(int argc, char* argv[])
         [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
             using Layout = decltype(layout);
 
-            if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
+            if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
             {
                 return (nRow - 1) * stride + nCol;
             }

@@ -84,7 +84,7 @@ int main(int argc, char* argv[])
         [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
             using Layout = decltype(layout);
 
-            if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
+            if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
            {
                 return (nRow - 1) * stride + nCol;
             }
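The same one-line change recurs across several client examples. A minimal standalone sketch of the pattern (illustrative names, not from the repository): the layout is passed as a tag type, and `if constexpr` discards the non-matching branch at compile time instead of compiling both.

    // illustrative_if_constexpr.cpp: compile-time layout dispatch, using
    // simplified stand-in tag types rather than ck::tensor_layout.
    #include <cstddef>
    #include <iostream>
    #include <type_traits>

    struct RowMajor {};
    struct ColumnMajor {};

    // Smallest buffer (in elements) that holds an nRow x nCol matrix with the
    // given leading-dimension stride.
    template <typename Layout>
    std::size_t matrix_space_size(std::size_t nRow, std::size_t nCol, std::size_t stride)
    {
        if constexpr(std::is_same<Layout, RowMajor>::value)
        {
            return (nRow - 1) * stride + nCol; // start of last row + one row
        }
        else
        {
            return (nCol - 1) * stride + nRow; // start of last column + one column
        }
    }

    int main()
    {
        std::cout << matrix_space_size<RowMajor>(1024, 512, 1024) << '\n';    // 1048064
        std::cout << matrix_space_size<ColumnMajor>(1024, 512, 1024) << '\n'; // 524288
    }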
@@ -1,2 +1,5 @@
-add_executable(client_gemm_add_add_reduce_normalize gemm_add_add_layernorm.cpp)
-target_link_libraries(client_gemm_add_add_reduce_normalize PRIVATE composable_kernel::device_operations)
+add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp)
+target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_operations)
+
+add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp)
+target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_operations)
@@ -190,7 +190,7 @@ int main()
         [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
             using Layout = decltype(layout);
 
-            if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
+            if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
             {
                 return (nRow - 1) * stride + nCol;
             }
@@ -0,0 +1,244 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iomanip>
#include <iostream>
#include <vector>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using F16 = ck::half_t;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd  = ck::tensor_operation::element_wise::AddReluAdd;

// DataType
using ADataType     = F16;
using BDataType     = F16;
using D0DataType    = F16;
using D1DataType    = F16;
using GammaDataType = F16;
using BetaDataType  = F16;
using HDataType     = F16;

// Layout
using ALayout  = Row;
using BLayout  = Col;
using D0Layout = Row;
using D1Layout = Row;
using HLayout  = Row;

using AElementOp   = PassThrough;
using BElementOp   = PassThrough;
using CDEElementOp = AddReluAdd;
using HElementOp   = PassThrough;

struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}, mMemSize_(mem_size)
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    void* GetDeviceBuffer() { return p_mem_; }

    void SetZero() const { (void)hipMemset(p_mem_, 0, mMemSize_); }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
    std::size_t mMemSize_;
};

int main(int argc, char* argv[])
{
    // GEMM shape
    ck::index_t M = 1024;
    ck::index_t N = 1024;
    ck::index_t K = 1024;

    ck::index_t StrideA  = K;
    ck::index_t StrideB  = K;
    ck::index_t StrideD0 = 0;
    ck::index_t StrideD1 = N;
    ck::index_t StrideH  = N;
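    // Note (illustrative, not in the original file): StrideD0 = 0 makes every
    // row of D0 alias the same N elements, i.e. D0 acts as a length-N bias
    // broadcast across the M dimension, and its buffer below is accordingly
    // allocated as only N elements.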

    float epsilon = 1e-5;

    auto f_matrix_space_size =
        [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
            using Layout = decltype(layout);

            if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
            {
                return (nRow - 1) * stride + nCol;
            }
            else
            {
                return (nCol - 1) * stride + nRow;
            }
        };

    SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
    SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
    SimpleDeviceMem d0_device_buf(sizeof(D0DataType) *
                                  f_matrix_space_size(M, N, StrideD0, D0Layout{}));
    SimpleDeviceMem d1_device_buf(sizeof(D1DataType) *
                                  f_matrix_space_size(M, N, StrideD1, D1Layout{}));
    SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N);
    SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N);
    SimpleDeviceMem h_device_buf(sizeof(HDataType) * f_matrix_space_size(M, N, StrideH, HLayout{}));

    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm<
        ALayout,
        BLayout,
        ck::Tuple<D0Layout, D1Layout>,
        HLayout,
        ADataType,
        BDataType,
        ck::Tuple<D0DataType, D1DataType>,
        GammaDataType,
        BetaDataType,
        HDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::AddReluAdd,
        ck::tensor_operation::element_wise::PassThrough>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    const auto a_element_op   = AElementOp{};
    const auto b_element_op   = BElementOp{};
    const auto cde_element_op = CDEElementOp{};
    const auto h_element_op   = HElementOp{};

    std::string best_op_name;
    bool found            = false;
    int best_op_id        = -1;
    float best_ave_time   = std::numeric_limits<float>::max();
    float best_gb_per_sec = 0;

    // profile device operation instances
    std::cout << "Run all instances and do timing" << std::endl;

    for(int i = 0; i < op_ptrs.size(); ++i)
    {
        auto& op_ptr = op_ptrs[i];

        auto argument_ptr = op_ptr->MakeArgumentPointer(
            a_device_buf.GetDeviceBuffer(),
            b_device_buf.GetDeviceBuffer(),
            {d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
            gamma_device_buf.GetDeviceBuffer(),
            beta_device_buf.GetDeviceBuffer(),
            h_device_buf.GetDeviceBuffer(),
            M,
            N,
            K,
            StrideA,
            StrideB,
            {StrideD0, StrideD1},
            StrideH,
            epsilon,
            a_element_op,
            b_element_op,
            cde_element_op,
            h_element_op);

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
            SimpleDeviceMem workspace_dev(workspace_sz);
            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
            h_device_buf.SetZero();

            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            std::size_t num_byte =
                sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                (sizeof(D0DataType) + sizeof(D1DataType) + sizeof(HDataType)) * M * N +
                (sizeof(GammaDataType) + sizeof(BetaDataType)) * N;

            float gb_per_sec = num_byte / 1.E6 / ave_time;
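            // Note (illustrative, not in the original file): num_byte counts
            // one pass over A, B, D0, D1, gamma, and beta plus one write of H,
            // and ave_time is in ms, so bytes / 1.E6 / ms yields GB/s.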

            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
                      << op_name << std::endl;

            if(ave_time < best_ave_time)
            {
                found           = true;
                best_op_id      = i;
                best_op_name    = op_name;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }
        }
        else
        {
            std::cout << op_name << " does not support this problem" << std::endl;
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
              << best_op_name << std::endl;

    // run the best instance
    {
        auto& op_ptr = op_ptrs[best_op_id];

        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;
        auto argument_ptr = op_ptr->MakeArgumentPointer(
            a_device_buf.GetDeviceBuffer(),
            b_device_buf.GetDeviceBuffer(),
            {d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
            gamma_device_buf.GetDeviceBuffer(),
            beta_device_buf.GetDeviceBuffer(),
            h_device_buf.GetDeviceBuffer(),
            M,
            N,
            K,
            StrideA,
            StrideB,
            {StrideD0, StrideD1},
            StrideH,
            epsilon,
            a_element_op,
            b_element_op,
            cde_element_op,
            h_element_op);

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
            SimpleDeviceMem workspace_dev(workspace_sz);
            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
            h_device_buf.SetZero();

            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }

        std::cout << "Done" << std::endl;
    }

    return 0;
}
@@ -12,12 +12,12 @@
 
 #include "ck/library/tensor_operation_instance/gpu/normalization.hpp"
 
-using XDataType     = ck::half_t;
-using GammaDataType = ck::half_t;
-using BetaDataType  = ck::half_t;
-using YDataType     = ck::half_t;
-using AccDataType   = float;
-using PassThrough   = ck::tensor_operation::element_wise::PassThrough;
+using XDataType       = ck::half_t;
+using GammaDataType   = ck::half_t;
+using BetaDataType    = ck::half_t;
+using YDataType       = ck::half_t;
+using ComputeDataType = float;
+using PassThrough     = ck::tensor_operation::element_wise::PassThrough;
 
 constexpr int Rank         = 2;
 constexpr int NumReduceDim = 1;

@@ -54,7 +54,7 @@ int main(int argc, char* argv[])
     using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
                                                                        GammaDataType,
                                                                        BetaDataType,
-                                                                       AccDataType,
+                                                                       ComputeDataType,
                                                                        YDataType,
                                                                        PassThrough,
                                                                        Rank,
@@ -1,2 +0,0 @@
-add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp)
-target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_operations)

client_example/07_grouped_convnd_fwd/CMakeLists.txt (new file): 5 lines
@@ -0,0 +1,5 @@
add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp)
target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_operations)

add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp)
target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_operations)
client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp (new file): 229 lines
@@ -0,0 +1,229 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using InDataType  = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;

using InLayout    = ck::tensor_layout::convolution::GNWC;
using WeiLayout   = ck::tensor_layout::convolution::GKXC;
using OutLayout   = ck::tensor_layout::convolution::GNWK;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr ck::index_t NumDimSpatial = 1;
static constexpr ck::index_t G             = 32;
static constexpr ck::index_t N             = 256;
static constexpr ck::index_t K             = 192;
static constexpr ck::index_t C             = 192;
static constexpr ck::index_t X             = 3;
static constexpr ck::index_t Wi            = 28;
static constexpr ck::index_t Wo            = 28;

struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};

int main()
{
    std::array<ck::index_t, NumDimSpatial + 3> in_lengths{G, N, Wi, C};
    std::array<ck::index_t, NumDimSpatial + 3> in_strides{0, 0, 0, 1};

    std::array<ck::index_t, NumDimSpatial + 3> wei_lengths{G, K, X, C};
    std::array<ck::index_t, NumDimSpatial + 3> wei_strides{0, 0, 0, 1};

    std::array<ck::index_t, NumDimSpatial + 3> out_lengths{G, N, Wo, K};
    std::array<ck::index_t, NumDimSpatial + 3> out_strides{0, 0, 0, 1};

    std::partial_sum(rbegin(in_lengths),
                     std::prev(rend(in_lengths)),
                     std::next(rbegin(in_strides)),
                     std::multiplies<>{});
    std::partial_sum(rbegin(wei_lengths),
                     std::prev(rend(wei_lengths)),
                     std::next(rbegin(wei_strides)),
                     std::multiplies<>{});
    std::partial_sum(rbegin(out_lengths),
                     std::prev(rend(out_lengths)),
                     std::next(rbegin(out_strides)),
                     std::multiplies<>{});

    // transpose GNWC/GKXC/GNWK to GNCW/GKCX/GNCW
    std::rotate(rbegin(in_lengths),
                std::next(rbegin(in_lengths)),
                std::next(rbegin(in_lengths), NumDimSpatial + 1));
    std::rotate(rbegin(in_strides),
                std::next(rbegin(in_strides)),
                std::next(rbegin(in_strides), NumDimSpatial + 1));
    std::rotate(rbegin(wei_lengths),
                std::next(rbegin(wei_lengths)),
                std::next(rbegin(wei_lengths), NumDimSpatial + 1));
    std::rotate(rbegin(wei_strides),
                std::next(rbegin(wei_strides)),
                std::next(rbegin(wei_strides), NumDimSpatial + 1));
    std::rotate(rbegin(out_lengths),
                std::next(rbegin(out_lengths)),
                std::next(rbegin(out_lengths), NumDimSpatial + 1));
    std::rotate(rbegin(out_strides),
                std::next(rbegin(out_strides)),
                std::next(rbegin(out_strides), NumDimSpatial + 1));
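    // Note (illustrative, not in the original file): with NumDimSpatial = 1,
    // partial_sum fills packed strides as reverse cumulative products of the
    // lengths, e.g. in_strides becomes {N * Wi * C, Wi * C, C, 1} for
    // in_lengths {G, N, Wi, C}. The rotate then swaps the trailing
    // NumDimSpatial + 1 entries, so in_lengths -> {G, N, C, Wi} and
    // in_strides -> {N * Wi * C, Wi * C, 1, C}, relabeling the dimension
    // order from GNWC to GNCW without moving any data.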

    std::array<ck::index_t, NumDimSpatial> filter_strides{1};
    std::array<ck::index_t, NumDimSpatial> filter_dilations{1};
    std::array<ck::index_t, NumDimSpatial> input_left_pads{1};
    std::array<ck::index_t, NumDimSpatial> input_right_pads{1};

    SimpleDeviceMem in(sizeof(InDataType) * G * N * Wi * C);
    SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * X * C);
    SimpleDeviceMem out(sizeof(OutDataType) * G * N * Wo * K);

    using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<NumDimSpatial,
                                                                                 InLayout,
                                                                                 WeiLayout,
                                                                                 ck::Tuple<>,
                                                                                 OutLayout,
                                                                                 InDataType,
                                                                                 WeiDataType,
                                                                                 ck::Tuple<>,
                                                                                 OutDataType,
                                                                                 PassThrough,
                                                                                 PassThrough,
                                                                                 PassThrough>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    std::string best_op_name;
    int best_op_id        = -1;
    float best_avg_time   = std::numeric_limits<float>::max();
    float best_gb_per_sec = 0;
    float best_tflops     = 0;

    // profile device operation instances
    std::cout << "Run all instances and do timing" << std::endl;

    for(int i = 0; i < op_ptrs.size(); ++i)
    {
        auto& op_ptr      = op_ptrs[i];
        auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
                                                        wei.GetDeviceBuffer(),
                                                        {},
                                                        out.GetDeviceBuffer(),
                                                        in_lengths,
                                                        in_strides,
                                                        wei_lengths,
                                                        wei_strides,
                                                        {},
                                                        {},
                                                        out_lengths,
                                                        out_strides,
                                                        filter_strides,
                                                        filter_dilations,
                                                        input_left_pads,
                                                        input_right_pads,
                                                        PassThrough{},
                                                        PassThrough{},
                                                        PassThrough{});
        auto invoker_ptr    = op_ptr->MakeInvokerPointer();
        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            std::size_t flop      = std::size_t(2) * G * N * K * C * Wo * X;
            std::size_t num_bytes = sizeof(InDataType) * G * N * Wi * C +
                                    sizeof(WeiDataType) * G * K * X * C +
                                    sizeof(OutDataType) * G * N * Wo * K;

            float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
            float gb_per_sec = num_bytes / 1.E6 / avg_time;

            std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
                      << gb_per_sec << " GB/s, " << op_name << std::endl;

            if(tflops > best_tflops)
            {
                best_op_id      = i;
                best_op_name    = op_name;
                best_avg_time   = avg_time;
                best_gb_per_sec = gb_per_sec;
                best_tflops     = tflops;
            }
        }
        else
        {
            std::cerr << op_name << " does not support this problem" << std::endl;
        }
    }

    if(best_op_id < 0)
    {
        std::cerr << "no suitable instance" << std::endl;
        return EXIT_FAILURE;
    }

    std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
              << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

    // run the best instance
    {
        auto& op_ptr = op_ptrs[best_op_id];
        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;
        auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
                                                        wei.GetDeviceBuffer(),
                                                        {},
                                                        out.GetDeviceBuffer(),
                                                        in_lengths,
                                                        in_strides,
                                                        wei_lengths,
                                                        wei_strides,
                                                        {},
                                                        {},
                                                        out_lengths,
                                                        out_strides,
                                                        filter_strides,
                                                        filter_dilations,
                                                        input_left_pads,
                                                        input_right_pads,
                                                        PassThrough{},
                                                        PassThrough{},
                                                        PassThrough{});

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }

        std::cout << "Done" << std::endl;
    }
}
@@ -1,2 +1,5 @@
 add_executable(client_fused_attention fused_attention.cpp)
 target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_operations)
+
+add_executable(client_fused_attention_bias fused_attention_bias.cpp)
+target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_operations)
client_example/08_fused_attention/fused_attention_bias.cpp (new file): 226 lines
@@ -0,0 +1,226 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <vector>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using AElementOp    = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp   = ck::tensor_operation::element_wise::PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using B1ElementOp   = ck::tensor_operation::element_wise::PassThrough;
using CElementOp    = ck::tensor_operation::element_wise::PassThrough;

constexpr static auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;

using ADataType   = ck::half_t;
using B0DataType  = ck::half_t;
using B1DataType  = ck::half_t;
using CDataType   = ck::half_t;
using D0DataType  = ck::half_t;
using AccDataType = float;

struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};

int main(int argc, char* argv[])
{
    int G0 = 48;
    int G1 = 16;
    int M  = 1024;
    int N  = 1024;
    int K  = 64;
    int O  = 64;

    // A layout [G0, M, G1, K]
    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
    std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};

    // B0 layout [G0, N, G1, K]
    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
    std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};

    // B1 layout [G0, N, G1, O]
    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
    std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};

    // C layout [G0, M, G1, O]
    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
    std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};

    // D layout [G0, M, G1, N]
    std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
    std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};
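    // Note (illustrative, not in the original file): the *_lengths vectors are
    // ordered [G0, G1, M/N/O, K/N], while the comments above give the memory
    // order; the strides encode that memory order, e.g. for A the M dimension
    // has stride G1 * K and the G1 dimension has stride K, matching the
    // [G0, M, G1, K] placement in memory.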

    SimpleDeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K);
    SimpleDeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K);
    SimpleDeviceMem d0_device_buf(sizeof(D0DataType) * G0 * G1 * M * N);
    SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N);
    SimpleDeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O);

    using DeviceOp =
        ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
                                                                          1,
                                                                          1,
                                                                          1,
                                                                          1,
                                                                          ADataType,
                                                                          B0DataType,
                                                                          B1DataType,
                                                                          CDataType,
                                                                          ck::Tuple<D0DataType>,
                                                                          ck::Tuple<>,
                                                                          AElementOp,
                                                                          B0ElementOp,
                                                                          Acc0ElementOp,
                                                                          B1ElementOp,
                                                                          CElementOp,
                                                                          MaskingSpec>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    std::string best_op_name;
    int best_op_id        = -1;
    float best_ave_time   = 0;
    float best_tflops     = 0;
    float best_gb_per_sec = 0;

    // profile device op instances
    std::cout << "Run all instances and do timing" << std::endl;

    for(int i = 0; i < op_ptrs.size(); ++i)
    {
        auto& op_ptr      = op_ptrs[i];
        auto argument_ptr = op_ptr->MakeArgumentPointer(
            a_device_buf.GetDeviceBuffer(),
            b0_device_buf.GetDeviceBuffer(),
            b1_device_buf.GetDeviceBuffer(),
            c_device_buf.GetDeviceBuffer(),
            std::array<void*, 1>{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases
            {},                                                    // p_acc1_biases
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b0_gs_ns_ks_lengths,
            b0_gs_ns_ks_strides,
            b1_gs_os_ns_lengths,
            b1_gs_os_ns_strides,
            c_gs_ms_os_lengths,
            c_gs_ms_os_strides,
            std::array<std::vector<ck::index_t>, 1>{
                d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
            std::array<std::vector<ck::index_t>, 1>{
                d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
            {},                       // acc1_biases_gs_ms_os_lengths
            {},                       // acc1_biases_gs_ms_os_strides
            AElementOp{},
            B0ElementOp{},
            Acc0ElementOp{1 / sqrtf(K)},
            B1ElementOp{},
            CElementOp{});
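        // Note (illustrative, not in the original file): Acc0ElementOp is
        // ScaleAdd, so the Q*K^T accumulator is presumably scaled by
        // 1 / sqrt(K), the usual attention normalization, with the D0 bias
        // added before the softmax.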

        auto invoker_ptr    = op_ptr->MakeInvokerPointer();
        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {

            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1;
            std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
                                     sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
                                     sizeof(D0DataType) * M * N) *
                                    G0 * G1;
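            // Note (illustrative, not in the original file): flop counts the
            // two GEMMs of fused attention, Q*K^T (2*M*N*K) and P*V (2*M*N*O),
            // over all G0 * G1 batches; the softmax itself is not counted.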

            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

            float gb_per_sec = num_btype / 1.E6 / ave_time;

            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                      << " GB/s, " << op_name << std::endl;

            if(tflops > best_tflops)
            {
                best_op_id      = i;
                best_op_name    = op_name;
                best_tflops     = tflops;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }
        }
        else
        {
            std::cout << op_name << " does not support this problem" << std::endl;
        }
    }

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

    // run the best instance
    {
        auto& op_ptr = op_ptrs[best_op_id];
        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;
        auto argument_ptr = op_ptr->MakeArgumentPointer(
            a_device_buf.GetDeviceBuffer(),
            b0_device_buf.GetDeviceBuffer(),
            b1_device_buf.GetDeviceBuffer(),
            c_device_buf.GetDeviceBuffer(),
            std::array<void*, 1>{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases
            {},                                                    // p_acc1_biases
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b0_gs_ns_ks_lengths,
            b0_gs_ns_ks_strides,
            b1_gs_os_ns_lengths,
            b1_gs_os_ns_strides,
            c_gs_ms_os_lengths,
            c_gs_ms_os_strides,
            std::array<std::vector<ck::index_t>, 1>{
                d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
            std::array<std::vector<ck::index_t>, 1>{
                d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
            {},                       // acc1_biases_gs_ms_os_lengths
            {},                       // acc1_biases_gs_ms_os_strides
            AElementOp{},
            B0ElementOp{},
            Acc0ElementOp{1 / sqrtf(K)},
            B1ElementOp{},
            CElementOp{});

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }

        std::cout << "Done" << std::endl;
    }

    return 0;
}
@@ -1,2 +1,9 @@
-add_executable(client_grouped_conv2d_bwd_weight grouped_conv2d_bwd_weight.cpp)
-target_link_libraries(client_grouped_conv2d_bwd_weight PRIVATE composable_kernel::device_operations)
+add_executable(client_grouped_conv1d_bwd_weight_fp16 grouped_conv1d_bwd_weight_fp16.cpp)
+add_executable(client_grouped_conv2d_bwd_weight_fp16 grouped_conv2d_bwd_weight_fp16.cpp)
+add_executable(client_grouped_conv3d_bwd_weight_fp16 grouped_conv3d_bwd_weight_fp16.cpp)
+add_executable(client_grouped_conv3d_bwd_weight_fp32 grouped_conv3d_bwd_weight_fp32.cpp)
+
+target_link_libraries(client_grouped_conv1d_bwd_weight_fp16 PRIVATE composable_kernel::device_operations)
+target_link_libraries(client_grouped_conv2d_bwd_weight_fp16 PRIVATE composable_kernel::device_operations)
+target_link_libraries(client_grouped_conv3d_bwd_weight_fp16 PRIVATE composable_kernel::device_operations)
+target_link_libraries(client_grouped_conv3d_bwd_weight_fp32 PRIVATE composable_kernel::device_operations)
@@ -13,27 +13,8 @@
 #include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
-using InDataType  = ck::half_t;
-using WeiDataType = ck::half_t;
-using OutDataType = ck::half_t;
-
-using InLayout  = ck::tensor_layout::convolution::GNHWC;
-using WeiLayout = ck::tensor_layout::convolution::GKYXC;
-using OutLayout = ck::tensor_layout::convolution::GNHWK;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 
-static constexpr ck::index_t NumDimSpatial = 2;
-static constexpr ck::index_t G             = 32;
-static constexpr ck::index_t N             = 256;
-static constexpr ck::index_t K             = 192;
-static constexpr ck::index_t C             = 192;
-static constexpr ck::index_t Y             = 3;
-static constexpr ck::index_t X             = 3;
-static constexpr ck::index_t Hi            = 28;
-static constexpr ck::index_t Wi            = 28;
-static constexpr ck::index_t Ho            = 28;
-static constexpr ck::index_t Wo            = 28;
-
 struct SimpleDeviceMem
 {
     SimpleDeviceMem() = delete;

@@ -50,22 +31,93 @@ struct SimpleDeviceMem
     void* p_mem_;
 };
 
-int main()
+template <ck::index_t NumDimSpatial>
+std::size_t GetFlops(ck::index_t G,
+                     ck::index_t N,
+                     ck::index_t K,
+                     ck::index_t C,
+                     const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
+                     const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths)
 {
-    std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Hi, Wi};
-    std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y, X};
-    std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo};
+    // 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
+    return static_cast<std::size_t>(2) * G * N * K * C *
+           std::accumulate(std::begin(output_spatial_lengths),
+                           std::end(output_spatial_lengths),
+                           static_cast<std::size_t>(1),
+                           std::multiplies<>()) *
+           std::accumulate(std::begin(filter_spatial_lengths),
+                           std::end(filter_spatial_lengths),
+                           static_cast<std::size_t>(1),
+                           std::multiplies<>());
+}
 
-    std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1};
-    std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1};
-    std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
-    std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1};
+template <typename InDataType, ck::index_t NumDimSpatial>
+std::size_t GetInputByte(ck::index_t G,
+                         ck::index_t N,
+                         ck::index_t C,
+                         const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths)
+{
+    // sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
+    return sizeof(InDataType) * (G * N * C *
+                                 std::accumulate(std::begin(input_spatial_lengths),
+                                                 std::end(input_spatial_lengths),
+                                                 static_cast<std::size_t>(1),
+                                                 std::multiplies<>()));
+}
+
+template <typename WeiDataType, ck::index_t NumDimSpatial>
+std::size_t GetWeightByte(ck::index_t G,
+                          ck::index_t K,
+                          ck::index_t C,
+                          const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths)
+{
+    // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
+    return sizeof(WeiDataType) * (G * K * C *
+                                  std::accumulate(std::begin(filter_spatial_lengths),
+                                                  std::end(filter_spatial_lengths),
+                                                  static_cast<std::size_t>(1),
+                                                  std::multiplies<>()));
+}
+
+template <typename OutDataType, ck::index_t NumDimSpatial>
+std::size_t GetOutputByte(ck::index_t G,
+                          ck::index_t N,
+                          ck::index_t K,
+                          const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths)
+{
+    // sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
+    return sizeof(OutDataType) * (G * N * K *
+                                  std::accumulate(std::begin(output_spatial_lengths),
+                                                  std::end(output_spatial_lengths),
+                                                  static_cast<std::size_t>(1),
+                                                  std::multiplies<std::size_t>()));
+}
+
+template <ck::index_t NumDimSpatial,
+          typename InDataType,
+          typename WeiDataType,
+          typename OutDataType,
+          typename InLayout,
+          typename WeiLayout,
+          typename OutLayout>
+bool run_grouped_conv_bwd_weight(
+    ck::index_t G,
+    ck::index_t N,
+    ck::index_t K,
+    ck::index_t C,
+    const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths,
+    const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
+    const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
+    const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides,
+    const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations,
+    const std::array<ck::index_t, NumDimSpatial>& input_left_pads,
+    const std::array<ck::index_t, NumDimSpatial>& input_right_pads)
+{
 
     ck::index_t split_k = 2;
 
-    SimpleDeviceMem in(sizeof(InDataType) * G * N * Hi * Wi * C);
-    SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C);
-    SimpleDeviceMem out(sizeof(OutDataType) * G * N * Ho * Wo * K);
+    SimpleDeviceMem in(GetInputByte<InDataType, NumDimSpatial>(G, N, C, input_spatial_lengths));
+    SimpleDeviceMem wei(GetWeightByte<WeiDataType, NumDimSpatial>(G, K, C, filter_spatial_lengths));
+    SimpleDeviceMem out(GetOutputByte<OutDataType, NumDimSpatial>(G, N, K, output_spatial_lengths));
 
     using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NumDimSpatial,
                                                                               InLayout,

@@ -120,10 +172,12 @@ int main()
         {
             float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
 
-            std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X;
-            std::size_t num_bytes = sizeof(InDataType) * G * N * Hi * Wi * C +
-                                    sizeof(WeiDataType) * G * K * Y * X * C +
-                                    sizeof(OutDataType) * G * N * Ho * Wo * K;
+            std::size_t flop =
+                GetFlops<NumDimSpatial>(G, N, K, C, output_spatial_lengths, filter_spatial_lengths);
+            std::size_t num_bytes =
+                GetInputByte<InDataType, NumDimSpatial>(G, N, C, input_spatial_lengths) +
+                GetWeightByte<WeiDataType, NumDimSpatial>(G, K, C, filter_spatial_lengths) +
+                GetOutputByte<OutDataType, NumDimSpatial>(G, N, K, output_spatial_lengths);
 
             float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
             float gb_per_sec = num_bytes / 1.E6 / avg_time;

@@ -149,7 +203,7 @@ int main()
     if(best_op_id < 0)
     {
         std::cerr << "no suitable instance" << std::endl;
-        return EXIT_FAILURE;
+        return false;
     }
 
     std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops

@@ -187,4 +241,6 @@ int main()
 
         std::cout << "Done" << std::endl;
     }
+
+    return true;
 }
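A quick arithmetic check of the refactored GetFlops (a hypothetical standalone sketch using the constants from the 2D client below, not output from the tool):

    // flops_check.cpp: GetFlops for G=32, N=256, K=C=192, 28x28 output, 3x3 filter.
    #include <cstdint>
    #include <iostream>

    int main()
    {
        std::uint64_t flop = 2ULL * 32 * 256 * 192 * 192 * (28 * 28) * (3 * 3);
        std::cout << flop << '\n'; // prints 4261681299456, i.e. ~4.26e12 FLOP per pass
    }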
@@ -0,0 +1,37 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

using InDataType  = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;

using InLayout  = ck::tensor_layout::convolution::GNWC;
using WeiLayout = ck::tensor_layout::convolution::GKXC;
using OutLayout = ck::tensor_layout::convolution::GNWK;

static constexpr ck::index_t NumDimSpatial = 1;
static constexpr ck::index_t G             = 32;
static constexpr ck::index_t N             = 256;
static constexpr ck::index_t K             = 192;
static constexpr ck::index_t C             = 192;
static constexpr ck::index_t X             = 3;
static constexpr ck::index_t Wi            = 28;
static constexpr ck::index_t Wo            = 28;

int main()
{
    return run_grouped_conv_bwd_weight<NumDimSpatial,
                                       InDataType,
                                       WeiDataType,
                                       OutDataType,
                                       InLayout,
                                       WeiLayout,
                                       OutLayout>(G, N, K, C, {Wi}, {X}, {Wo}, {1}, {1}, {1}, {1})
               ? EXIT_SUCCESS
               : EXIT_FAILURE;
}
@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

using InDataType  = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;

using InLayout  = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
using OutLayout = ck::tensor_layout::convolution::GNHWK;

static constexpr ck::index_t NumDimSpatial = 2;
static constexpr ck::index_t G             = 32;
static constexpr ck::index_t N             = 256;
static constexpr ck::index_t K             = 192;
static constexpr ck::index_t C             = 192;
static constexpr ck::index_t Y             = 3;
static constexpr ck::index_t X             = 3;
static constexpr ck::index_t Hi            = 28;
static constexpr ck::index_t Wi            = 28;
static constexpr ck::index_t Ho            = 28;
static constexpr ck::index_t Wo            = 28;

int main()
{
    return run_grouped_conv_bwd_weight<NumDimSpatial,
                                       InDataType,
                                       WeiDataType,
                                       OutDataType,
                                       InLayout,
                                       WeiLayout,
                                       OutLayout>(
               G, N, K, C, {Hi, Wi}, {Y, X}, {Ho, Wo}, {1, 1}, {1, 1}, {1, 1}, {1, 1})
               ? EXIT_SUCCESS
               : EXIT_FAILURE;
}
@@ -0,0 +1,53 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

using InDataType  = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;

using InLayout  = ck::tensor_layout::convolution::GNDHWC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::GNDHWK;

static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G             = 8;
static constexpr ck::index_t N             = 64;
static constexpr ck::index_t K             = 128;
static constexpr ck::index_t C             = 128;
static constexpr ck::index_t Z             = 3;
static constexpr ck::index_t Y             = 3;
static constexpr ck::index_t X             = 3;
static constexpr ck::index_t Di            = 28;
static constexpr ck::index_t Hi            = 28;
static constexpr ck::index_t Wi            = 3;
static constexpr ck::index_t Do            = 28;
static constexpr ck::index_t Ho            = 28;
static constexpr ck::index_t Wo            = 3;

int main()
{
    return run_grouped_conv_bwd_weight<NumDimSpatial,
                                       InDataType,
                                       WeiDataType,
                                       OutDataType,
                                       InLayout,
                                       WeiLayout,
                                       OutLayout>(G,
                                                  N,
                                                  K,
                                                  C,
                                                  {Di, Hi, Wi},
                                                  {Z, Y, X},
                                                  {Do, Ho, Wo},
                                                  {1, 1, 1},
                                                  {1, 1, 1},
                                                  {1, 1, 1},
                                                  {1, 1, 1})
               ? EXIT_SUCCESS
               : EXIT_FAILURE;
}
@@ -0,0 +1,53 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

using InDataType  = float;
using WeiDataType = float;
using OutDataType = float;

using InLayout  = ck::tensor_layout::convolution::GNDHWC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::GNDHWK;

static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G             = 8;
static constexpr ck::index_t N             = 64;
static constexpr ck::index_t K             = 128;
static constexpr ck::index_t C             = 128;
static constexpr ck::index_t Z             = 3;
static constexpr ck::index_t Y             = 3;
static constexpr ck::index_t X             = 3;
static constexpr ck::index_t Di            = 28;
static constexpr ck::index_t Hi            = 28;
static constexpr ck::index_t Wi            = 3;
static constexpr ck::index_t Do            = 28;
static constexpr ck::index_t Ho            = 28;
static constexpr ck::index_t Wo            = 3;

int main()
{
    return run_grouped_conv_bwd_weight<NumDimSpatial,
                                       InDataType,
                                       WeiDataType,
                                       OutDataType,
                                       InLayout,
                                       WeiLayout,
                                       OutLayout>(G,
                                                  N,
                                                  K,
                                                  C,
                                                  {Di, Hi, Wi},
                                                  {Z, Y, X},
                                                  {Do, Ho, Wo},
                                                  {1, 1, 1},
                                                  {1, 1, 1},
                                                  {1, 1, 1},
                                                  {1, 1, 1})
               ? EXIT_SUCCESS
               : EXIT_FAILURE;
}
client_example/15_convnd_bwd_data/CMakeLists.txt (new file): 5 lines
@@ -0,0 +1,5 @@
add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp)
add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp)

target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_operations)
target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_operations)
client_example/15_convnd_bwd_data/common.hpp (new file): 233 lines
@@ -0,0 +1,233 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <string>
#include <vector>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};

std::size_t GetFlops(ck::index_t N,
                     ck::index_t K,
                     ck::index_t C,
                     const std::vector<ck::index_t>& output_spatial_lengths,
                     const std::vector<ck::index_t>& weights_spatial_lengths)
{
    // 2 * N * K * C * <output spatial lengths product> * <filter spatial lengths product>

    return static_cast<std::size_t>(2) * N * K * C *
           std::accumulate(std::begin(output_spatial_lengths),
                           std::end(output_spatial_lengths),
                           static_cast<std::size_t>(1),
                           std::multiplies<>()) *
           std::accumulate(std::begin(weights_spatial_lengths),
                           std::end(weights_spatial_lengths),
                           static_cast<std::size_t>(1),
                           std::multiplies<>());
}

template <typename InDataType>
std::size_t
GetInputByte(ck::index_t N, ck::index_t C, const std::vector<ck::index_t>& input_spatial_lengths)
{
    // sizeof(InDataType) * (N * C * <input spatial lengths product>) +
    return sizeof(InDataType) * N * C *
           std::accumulate(std::begin(input_spatial_lengths),
                           std::end(input_spatial_lengths),
                           static_cast<std::size_t>(1),
                           std::multiplies<>());
}

template <typename WeiDataType>
std::size_t
GetWeightByte(ck::index_t K, ck::index_t C, const std::vector<ck::index_t>& weights_spatial_lengths)
{
    // sizeof(WeiDataType) * (K * C * <filter spatial lengths product>) +
    return sizeof(WeiDataType) * K * C *
           std::accumulate(std::begin(weights_spatial_lengths),
                           std::end(weights_spatial_lengths),
                           static_cast<std::size_t>(1),
                           std::multiplies<>());
}

template <typename OutDataType>
std::size_t
GetOutputByte(ck::index_t N, ck::index_t K, const std::vector<ck::index_t>& output_spatial_lengths)
{
    // sizeof(OutDataType) * (N * K * <output spatial lengths product>);
    return sizeof(OutDataType) * N * K *
           std::accumulate(std::begin(output_spatial_lengths),
                           std::end(output_spatial_lengths),
                           static_cast<std::size_t>(1),
                           std::multiplies<std::size_t>());
}

template <ck::index_t NumDimSpatial,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout>
bool run_conv_bwd_data(ck::index_t N,
                       ck::index_t K,
                       ck::index_t C,
                       const std::vector<ck::index_t>& in_spatial_lengths,
                       const std::vector<ck::index_t>& wei_spatial_lengths,
                       const std::vector<ck::index_t>& out_spatial_lengths)
{
    std::size_t in_mem_size  = GetInputByte<InDataType>(N, C, in_spatial_lengths);
    std::size_t wei_mem_size = GetWeightByte<WeiDataType>(K, C, wei_spatial_lengths);
    std::size_t out_mem_size = GetOutputByte<OutDataType>(N, K, out_spatial_lengths);

    SimpleDeviceMem in(in_mem_size);
    SimpleDeviceMem wei(wei_mem_size);
    SimpleDeviceMem out(out_mem_size);

    std::vector<ck::index_t> filter_strides(NumDimSpatial, 1);
    std::vector<ck::index_t> filter_dilations(NumDimSpatial, 1);
    std::vector<ck::index_t> input_left_pads(NumDimSpatial, 1);
    std::vector<ck::index_t> input_right_pads(NumDimSpatial, 1);
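    // Note (illustrative, not in the original file): the convolution
    // parameters are fixed here: unit filter strides and dilations, and
    // one-element padding on both sides of every spatial dimension.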
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceConvBwdData<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_op_name;
|
||||
int best_op_id = -1;
|
||||
float best_avg_time = std::numeric_limits<float>::max();
|
||||
float best_gb_per_sec = 0;
|
||||
float best_tflops = 0;
|
||||
|
||||
std::size_t flop = GetFlops(N, K, C, out_spatial_lengths, wei_spatial_lengths);
|
||||
std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size;
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
for(int i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
out.GetDeviceBuffer(),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
in_spatial_lengths,
|
||||
wei_spatial_lengths,
|
||||
out_spatial_lengths,
|
||||
filter_strides,
|
||||
filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
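            // avg_time is reported in milliseconds, so flop / 1e9 / ms yields
            // TFLOPS and bytes / 1e6 / ms yields GB/s (the factor of 1e3 for
            // the ms-to-seconds conversion is folded into the constants).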
            float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
            float gb_per_sec = num_bytes / 1.E6 / avg_time;

            std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
                      << gb_per_sec << " GB/s, " << op_name << std::endl;

            if(tflops > best_tflops)
            {
                best_op_id      = i;
                best_op_name    = op_name;
                best_avg_time   = avg_time;
                best_gb_per_sec = gb_per_sec;
                best_tflops     = tflops;
            }
        }
        else
        {
            std::cerr << op_name << " does not support this problem" << std::endl;
        }
    }

    if(best_op_id < 0)
    {
        std::cerr << "no suitable instance" << std::endl;
        return false;
    }

    std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
              << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

    // run the best instance
    {
        auto& op_ptr = op_ptrs[best_op_id];
        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;
        auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
                                                        wei.GetDeviceBuffer(),
                                                        out.GetDeviceBuffer(),
                                                        N,
                                                        K,
                                                        C,
                                                        in_spatial_lengths,
                                                        wei_spatial_lengths,
                                                        out_spatial_lengths,
                                                        filter_strides,
                                                        filter_dilations,
                                                        input_left_pads,
                                                        input_right_pads,
                                                        PassThrough{},
                                                        PassThrough{},
                                                        PassThrough{});

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }

        std::cout << "Done" << std::endl;
    }
    return true;
}
42
client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp
Normal file
@@ -0,0 +1,42 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

using InDataType  = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;

using InLayout  = ck::tensor_layout::convolution::NDHWC;
using WeiLayout = ck::tensor_layout::convolution::KZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWK;

static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t N             = 64;
static constexpr ck::index_t K             = 128;
static constexpr ck::index_t C             = 64;
static constexpr ck::index_t Z             = 3;
static constexpr ck::index_t Y             = 3;
static constexpr ck::index_t X             = 3;
static constexpr ck::index_t Di            = 28;
static constexpr ck::index_t Hi            = 28;
static constexpr ck::index_t Wi            = 28;
static constexpr ck::index_t Do            = 28;
static constexpr ck::index_t Ho            = 28;
static constexpr ck::index_t Wo            = 28;

int main()
{
    return run_conv_bwd_data<NumDimSpatial,
                             InDataType,
                             WeiDataType,
                             OutDataType,
                             InLayout,
                             WeiLayout,
                             OutLayout>(N, K, C, {Di, Hi, Wi}, {Z, Y, X}, {Do, Ho, Wo})
               ? EXIT_SUCCESS
               : EXIT_FAILURE;
}
42
client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp
Normal file
@@ -0,0 +1,42 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

using InDataType  = float;
using WeiDataType = float;
using OutDataType = float;

using InLayout  = ck::tensor_layout::convolution::NDHWC;
using WeiLayout = ck::tensor_layout::convolution::KZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWK;

static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t N             = 64;
static constexpr ck::index_t K             = 128;
static constexpr ck::index_t C             = 64;
static constexpr ck::index_t Z             = 3;
static constexpr ck::index_t Y             = 3;
static constexpr ck::index_t X             = 3;
static constexpr ck::index_t Di            = 28;
static constexpr ck::index_t Hi            = 28;
static constexpr ck::index_t Wi            = 28;
static constexpr ck::index_t Do            = 28;
static constexpr ck::index_t Ho            = 28;
static constexpr ck::index_t Wo            = 28;

int main()
{
    return run_conv_bwd_data<NumDimSpatial,
                             InDataType,
                             WeiDataType,
                             OutDataType,
                             InLayout,
                             WeiLayout,
                             OutLayout>(N, K, C, {Di, Hi, Wi}, {Z, Y, X}, {Do, Ho, Wo})
               ? EXIT_SUCCESS
               : EXIT_FAILURE;
}
@@ -92,7 +92,7 @@ int main(int argc, char* argv[])
        [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
            using Layout = decltype(layout);

            if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
            if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
            {
                return (nRow - 1) * stride + nCol;
            }
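The hunk above swaps a runtime `if` for `if constexpr`, so the layout dispatch inside the generic lambda is resolved at compile time and the untaken branch is discarded rather than instantiated. A minimal standalone sketch of the pattern (names here are illustrative, not taken from the repository):

```cpp
#include <cstddef>
#include <type_traits>

struct RowMajor {};
struct ColumnMajor {};

// Span of a 2D matrix in elements, given its leading-dimension stride.
// The formula depends on which dimension is contiguous, so the branch is
// selected per layout type at compile time.
template <typename Layout>
std::size_t matrix_span(std::size_t nRow, std::size_t nCol, std::size_t stride)
{
    if constexpr(std::is_same<Layout, RowMajor>::value)
    {
        return (nRow - 1) * stride + nCol; // rows are strided, columns packed
    }
    else
    {
        return (nCol - 1) * stride + nRow; // columns are strided, rows packed
    }
}
```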
5
client_example/16_convnd_fwd/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp)
add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp)

target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations)
target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations)
304
client_example/16_convnd_fwd/common.hpp
Normal file
@@ -0,0 +1,304 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <algorithm>
#include <array>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <string>
#include <vector>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};

template <ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetFlops(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& output_lengths,
         const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& weights_lengths)
{
    // 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
    ck::index_t G = weights_lengths[0];
    ck::index_t N = output_lengths[1];
    ck::index_t K = weights_lengths[1];
    ck::index_t C = weights_lengths[2];

    return static_cast<std::size_t>(2) * G * N * K * C *
           std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim),
                           std::end(output_lengths),
                           static_cast<std::size_t>(1),
                           std::multiplies<>()) *
           std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim),
                           std::end(weights_lengths),
                           static_cast<std::size_t>(1),
                           std::multiplies<>());
}
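
// Note: the byte-count helpers below assume packed (fully contiguous) tensors;
// run_grouped_conv_fwd further down computes matching packed strides.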

template <typename InDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetInputByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& input_lengths)
{
    // sizeof(InDataType) * (G * N * C * <input spatial lengths product>)
    return sizeof(InDataType) * std::accumulate(std::begin(input_lengths),
                                                std::end(input_lengths),
                                                static_cast<std::size_t>(1),
                                                std::multiplies<>());
}

template <typename WeiDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetWeightByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& weights_lengths)
{
    // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>)
    return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths),
                                                 std::end(weights_lengths),
                                                 static_cast<std::size_t>(1),
                                                 std::multiplies<>());
}

template <typename OutDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetOutputByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& output_lengths)
{
    // sizeof(OutDataType) * (G * N * K * <output spatial lengths product>)
    return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths),
                                                 std::end(output_lengths),
                                                 static_cast<std::size_t>(1),
                                                 std::multiplies<std::size_t>());
}

template <ck::index_t NumDimSpatial,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          ck::index_t NumNonSpatialDim = 3>
bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_lengths,
                          std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_lengths,
                          std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_lengths)
{
    std::size_t in_mem_size  = GetInputByte<InDataType, NumDimSpatial>(in_lengths);
    std::size_t wei_mem_size = GetWeightByte<WeiDataType, NumDimSpatial>(wei_lengths);
    std::size_t out_mem_size = GetOutputByte<OutDataType, NumDimSpatial>(out_lengths);

    SimpleDeviceMem in(in_mem_size);
    SimpleDeviceMem wei(wei_mem_size);
    SimpleDeviceMem out(out_mem_size);

    std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_strides;
    std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_strides;
    std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_strides;
    in_strides.fill(0);
    wei_strides.fill(0);
    out_strides.fill(0);
    in_strides.back()  = 1;
    wei_strides.back() = 1;
    out_strides.back() = 1;

    std::partial_sum(rbegin(in_lengths),
                     std::prev(rend(in_lengths)),
                     std::next(rbegin(in_strides)),
                     std::multiplies<>{});
    std::partial_sum(rbegin(wei_lengths),
                     std::prev(rend(wei_lengths)),
                     std::next(rbegin(wei_strides)),
                     std::multiplies<>{});
    std::partial_sum(rbegin(out_lengths),
                     std::prev(rend(out_lengths)),
                     std::next(rbegin(out_strides)),
                     std::multiplies<>{});
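
    // Example: with packed lengths {N, Di, Hi, Wi, G, C}, the strides become
    // {Di*Hi*Wi*G*C, Hi*Wi*G*C, Wi*G*C, G*C, C, 1}, i.e. a right-to-left
    // running product of the lengths.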

    // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW
    std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths));
    std::rotate(rbegin(in_lengths),
                std::next(rbegin(in_lengths)),
                std::next(rbegin(in_lengths), NumDimSpatial + 1));

    std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides));
    std::rotate(rbegin(in_strides),
                std::next(rbegin(in_strides)),
                std::next(rbegin(in_strides), NumDimSpatial + 1));

    std::rotate(
        std::next(rbegin(wei_lengths)), std::next(rbegin(wei_lengths), 2), rend(wei_lengths));
    std::rotate(rbegin(wei_lengths),
                std::next(rbegin(wei_lengths)),
                std::next(rbegin(wei_lengths), NumDimSpatial + 1));

    std::rotate(
        std::next(rbegin(wei_strides)), std::next(rbegin(wei_strides), 2), rend(wei_strides));
    std::rotate(rbegin(wei_strides),
                std::next(rbegin(wei_strides)),
                std::next(rbegin(wei_strides), NumDimSpatial + 1));

    std::rotate(
        std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths));
    std::rotate(rbegin(out_lengths),
                std::next(rbegin(out_lengths)),
                std::next(rbegin(out_lengths), NumDimSpatial + 1));

    std::rotate(
        std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides));
    std::rotate(rbegin(out_strides),
                std::next(rbegin(out_strides)),
                std::next(rbegin(out_strides), NumDimSpatial + 1));

    std::array<ck::index_t, NumDimSpatial> conv_filter_strides;
    std::array<ck::index_t, NumDimSpatial> conv_filter_dilations;
    std::array<ck::index_t, NumDimSpatial> input_left_pads;
    std::array<ck::index_t, NumDimSpatial> input_right_pads;
    conv_filter_strides.fill(1);
    conv_filter_dilations.fill(1);
    input_left_pads.fill(1);
    input_right_pads.fill(1);

    std::size_t flop      = GetFlops<NumDimSpatial>(out_lengths, wei_lengths);
    std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size;

    using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<NumDimSpatial,
                                                                                 InLayout,
                                                                                 WeiLayout,
                                                                                 ck::Tuple<>,
                                                                                 OutLayout,
                                                                                 InDataType,
                                                                                 WeiDataType,
                                                                                 ck::Tuple<>,
                                                                                 OutDataType,
                                                                                 PassThrough,
                                                                                 PassThrough,
                                                                                 PassThrough>;
    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    std::string best_op_name;
    int best_op_id        = -1;
    float best_avg_time   = std::numeric_limits<float>::max();
    float best_gb_per_sec = 0;
    float best_tflops     = 0;

    // profile device operation instances
    std::cout << "Run all instances and do timing" << std::endl;

    for(int i = 0; i < op_ptrs.size(); ++i)
    {
        auto& op_ptr      = op_ptrs[i];
        auto argument_ptr = op_ptr->MakeArgumentPointer(
            in.GetDeviceBuffer(),
            wei.GetDeviceBuffer(),
            std::array<const void*, 0>{},
            out.GetDeviceBuffer(),
            in_lengths,
            in_strides,
            wei_lengths,
            wei_strides,
            std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
            std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
            out_lengths,
            out_strides,
            conv_filter_strides,
            conv_filter_dilations,
            input_left_pads,
            input_right_pads,
            PassThrough{},
            PassThrough{},
            PassThrough{});

        auto invoker_ptr    = op_ptr->MakeInvokerPointer();
        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
            float gb_per_sec = num_bytes / 1.E6 / avg_time;

            std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
                      << gb_per_sec << " GB/s, " << op_name << std::endl;

            if(tflops > best_tflops)
            {
                best_op_id      = i;
                best_op_name    = op_name;
                best_avg_time   = avg_time;
                best_gb_per_sec = gb_per_sec;
                best_tflops     = tflops;
            }
        }
        else
        {
            std::cerr << op_name << " does not support this problem" << std::endl;
        }
    }

    if(best_op_id < 0)
    {
        std::cerr << "no suitable instance" << std::endl;
        return false;
    }

    std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
              << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

    // run the best instance
    {
        auto& op_ptr = op_ptrs[best_op_id];
        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;
        auto argument_ptr = op_ptr->MakeArgumentPointer(
            in.GetDeviceBuffer(),
            wei.GetDeviceBuffer(),
            std::array<const void*, 0>{},
            out.GetDeviceBuffer(),
            in_lengths,
            in_strides,
            wei_lengths,
            wei_strides,
            std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
            std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
            out_lengths,
            out_strides,
            conv_filter_strides,
            conv_filter_dilations,
            input_left_pads,
            input_right_pads,
            PassThrough{},
            PassThrough{},
            PassThrough{});

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }

        std::cout << "Done" << std::endl;
    }
    return true;
}
44
client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp
Normal file
@@ -0,0 +1,44 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

using InDataType  = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;

using InLayout  = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::KZYXGC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;

static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G             = 1;
static constexpr ck::index_t N             = 64;
static constexpr ck::index_t K             = 128;
static constexpr ck::index_t C             = 64;
static constexpr ck::index_t Z             = 3;
static constexpr ck::index_t Y             = 3;
static constexpr ck::index_t X             = 3;
static constexpr ck::index_t Di            = 28;
static constexpr ck::index_t Hi            = 28;
static constexpr ck::index_t Wi            = 3;
static constexpr ck::index_t Do            = 28;
static constexpr ck::index_t Ho            = 28;
static constexpr ck::index_t Wo            = 3;

int main()
{
    return run_grouped_conv_fwd<NumDimSpatial,
                                InDataType,
                                WeiDataType,
                                OutDataType,
                                InLayout,
                                WeiLayout,
                                OutLayout>(
               {N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K})
               ? EXIT_SUCCESS
               : EXIT_FAILURE;
}
44
client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp
Normal file
@@ -0,0 +1,44 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

using InDataType  = float;
using WeiDataType = float;
using OutDataType = float;

using InLayout  = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::KZYXGC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;

static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G             = 1;
static constexpr ck::index_t N             = 64;
static constexpr ck::index_t K             = 128;
static constexpr ck::index_t C             = 64;
static constexpr ck::index_t Z             = 3;
static constexpr ck::index_t Y             = 3;
static constexpr ck::index_t X             = 3;
static constexpr ck::index_t Di            = 28;
static constexpr ck::index_t Hi            = 28;
static constexpr ck::index_t Wi            = 3;
static constexpr ck::index_t Do            = 28;
static constexpr ck::index_t Ho            = 28;
static constexpr ck::index_t Wo            = 3;

int main()
{
    return run_grouped_conv_fwd<NumDimSpatial,
                                InDataType,
                                WeiDataType,
                                OutDataType,
                                InLayout,
                                WeiLayout,
                                OutLayout>(
               {N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K})
               ? EXIT_SUCCESS
               : EXIT_FAILURE;
}
191
doc/markdown/tutorial_hello_world.md
Normal file
@@ -0,0 +1,191 @@
## CK Hello world

## Motivation

This tutorial is aimed at engineers dealing with artificial intelligence and machine learning who would like to optimize their pipelines and squeeze out every drop of performance by adding the Composable Kernel (CK) library to their projects. We would like to make the CK library approachable, so the tutorial is not based on the latest release and doesn't have all the bleeding-edge features, but it will be reproducible now and forever.

During this tutorial we will have an introduction to the CK library, build it, and run some examples and tests; so to say, we will run a "Hello world" example. In future tutorials we will go in depth and breadth and get familiar with other tools and ways to integrate CK into your project.

## Description

Modern AI technology solves more and more problems in all imaginable fields, but crafting fast and efficient workflows is still challenging. CK is one of the tools that makes AI heavy lifting as fast and efficient as possible. CK is a collection of optimized AI operator kernels and tools to create new ones. The library has components required for the majority of modern neural network architectures, including matrix multiplication, convolution, contraction, reduction, attention modules, a variety of activation functions, fused operators and many more.

So how do we (almost) reach the speed of light? CK's acceleration abilities are based on:

* Layered structure.
* Tile-based computation model.
* Tensor coordinate transformation.
* Hardware acceleration use.
* Support of low precision data types including fp16, bf16, int8 and int4.

If you are excited and need more technical details and benchmarking results, read this awesome blog [post](https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224).

For more details visit our [github repo](https://github.com/ROCmSoftwarePlatform/composable_kernel).

## Hardware targets

The CK library fully supports the "gfx908" and "gfx90a" GPU architectures, while only some operators are supported for "gfx1030". Let's check the hardware you have at hand and decide on the target GPU architecture:

| GPU Target | AMD GPU |
|------------|---------|
| gfx908     | Radeon Instinct MI100 |
| gfx90a     | Radeon Instinct MI210, MI250, MI250X |
| gfx1030    | Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6800 XT, RX 6900 XT, RX 6900 XTX, RX 6950 XT |

There are also [cloud options](https://aws.amazon.com/ec2/instance-types/g4/) you can use if you don't have an AMD GPU at hand.

## Build the library

First, let's clone the library and check out the tested version:

```
git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git
cd composable_kernel/
git checkout tutorial_hello_world
```

To make our lives easier we prepared [docker images](https://hub.docker.com/r/rocm/composable_kernel) with all the necessary dependencies. Pick the right image and create a container. In this tutorial we use the "rocm/composable_kernel:ck_ub20.04_rocm5.3_release" image; it is based on Ubuntu 20.04, ROCm v5.3, and the release compiler version.

If your current folder is ${HOME}, start the docker container with

```
docker run \
-it \
--privileged \
--group-add sudo \
-w /root/workspace \
-v ${HOME}:/root/workspace \
rocm/composable_kernel:ck_ub20.04_rocm5.3_release \
/bin/bash
```

If your current folder is different from ${HOME}, adjust the line `-v ${HOME}:/root/workspace` to fit your folder structure.

Inside the docker container the current folder is "~/workspace" and the library path is "~/workspace/composable_kernel"; navigate to the library

```
cd composable_kernel/
```

Create and go to the "build" directory

```
mkdir build && cd build
```

In the previous section we talked about target GPU architectures. Once you decide which one is right for you, run cmake with the right GPU_TARGETS flag

```
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-O3" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=OFF \
-D GPU_TARGETS="gfx908;gfx90a;gfx1030" ..
```

If everything went well, the cmake run will end with:

```
-- Configuring done
-- Generating done
-- Build files have been written to: "/root/workspace/composable_kernel/build"
```

Finally, we can build the examples and tests

```
make -j examples tests
```

If everything is smooth, you'll see

```
Scanning dependencies of target tests
[100%] Built target tests
```

## Run examples and tests

Examples are listed as test cases as well, so we can run all examples and tests with

```
ctest
```

You can check the list of all tests by running

```
ctest -N
```

We can also run them separately; here is an example execution.

```
./bin/example_gemm_xdl_fp16 1 1 1
```

The arguments "1 1 1" mean that we want to run this example in the mode: verify results on the CPU, initialize matrices with integers, and benchmark the kernel execution. You can play around with these parameters and see how the output and execution results change.

If everything goes well and you have a device based on the gfx908 or gfx90a architecture, you should see something like

```
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 1.10017 ms, 117.117 TFlops, 87.6854 GB/s, DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1
```

Meanwhile, running it on a gfx1030 device should result in

```
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem
```

But don't panic, some of the operators are supported on the gfx1030 architecture, so you can run an example like

```
./bin/example_gemm_dl_fp16 1 1 1
```

and it should result in something similar to

```
a_m_k: dim 2, lengths {3840, 4096}, strides {1, 4096}
b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2}
arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2}
arg.c_grid_desc_m_n_{ 3840, 4096}
launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 3.65695 ms, 35.234 TFlops, 26.3797 GB/s, DeviceGemmDl<256, 128, 128, 16, 2, 4, 4, 1>
```

Or we can run a separate test

```
ctest -R test_gemm_fp16
```

If everything goes well you should see something like

```
Start 121: test_gemm_fp16
1/1 Test #121: test_gemm_fp16 ................... Passed 51.81 sec

100% tests passed, 0 tests failed out of 1
```

## Summary

In this tutorial we took a first look at the Composable Kernel library, built it on your system, and ran some examples and tests. Stay tuned: in the next tutorial we will run kernels with different configs to find the best one for your hardware and task.

P.S.: Don't forget to switch off the cloud instance if you have launched one; you can find better ways to spend your money for sure!
2453
docs/Doxyfile
Normal file
File diff suppressed because it is too large
15
docs/run_doc.sh
Executable file
15
docs/run_doc.sh
Executable file
@@ -0,0 +1,15 @@
#!/bin/bash

set -eu

# Make this directory the PWD
cd "$(dirname "${BASH_SOURCE[0]}")"

# Build doxygen info
bash run_doxygen.sh

# Build sphinx docs
cd source
make clean
make -e SPHINXOPTS="-t html" html
make latexpdf
10
docs/run_doxygen.sh
Executable file
@@ -0,0 +1,10 @@
#!/bin/bash

set -eu

# Make this directory the PWD
cd "$(dirname "${BASH_SOURCE[0]}")"

# Build the doxygen info
rm -rf docBin
doxygen Doxyfile
23
docs/source/API_Reference_Guide.rst
Normal file
@@ -0,0 +1,23 @@

===================
API Reference Guide
===================

------------
Introduction
------------

This document contains details of the APIs for the Composable Kernel (CK) library and introduces some of the key design
principles that are used to write new classes that extend CK functionality.

=================
Using CK API
=================

This section describes how to use the CK library API.

-----------------
CK Datatypes
-----------------

[TODO]
8
docs/source/Contributors_Guide.rst
Normal file
@@ -0,0 +1,8 @@
===================
Contributor's Guide
===================

Pull-request guidelines
=======================

[TODO]
13
docs/source/Disclaimer.rst
Normal file
@@ -0,0 +1,13 @@
************
Disclaimer
************
-------------------------------
AMD's standard legal Disclaimer
-------------------------------

The information presented in this document is for informational purposes only and may contain technical inaccuracies, omissions, and typographical errors. The information contained herein is subject to change and may be rendered inaccurate for many reasons, including but not limited to product and roadmap changes, component and motherboard version changes, new model and/or product releases, product differences between differing manufacturers, software changes, BIOS flashes, firmware upgrades, or the like. Any computer system has risks of security vulnerabilities that cannot be completely prevented or mitigated. AMD assumes no obligation to update or otherwise correct or revise this information. However, AMD reserves the right to revise this information and to make changes from time to time to the content hereof without obligation of AMD to notify any person of such revisions or changes. THIS INFORMATION IS PROVIDED "AS IS." AMD MAKES NO REPRESENTATIONS OR WARRANTIES WITH RESPECT TO THE CONTENTS HEREOF AND ASSUMES NO RESPONSIBILITY FOR ANY INACCURACIES, ERRORS, OR OMISSIONS THAT MAY APPEAR IN THIS INFORMATION. AMD SPECIFICALLY DISCLAIMS ANY IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR ANY PARTICULAR PURPOSE. IN NO EVENT WILL AMD BE LIABLE TO ANY PERSON FOR ANY RELIANCE, DIRECT, INDIRECT, SPECIAL, OR OTHER CONSEQUENTIAL DAMAGES ARISING FROM THE USE OF ANY INFORMATION CONTAINED HEREIN, EVEN IF AMD IS EXPRESSLY ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. AMD, the AMD Arrow logo, Radeon, Ryzen, Epyc, and combinations thereof are trademarks of Advanced Micro Devices, Inc. Other product names used in this publication are for identification purposes only and may be trademarks of their respective companies. Google(R) is a registered trademark of Google LLC. PCIe(R) is a registered trademark of PCI-SIG Corporation. Linux(R) is the registered trademark of Linus Torvalds in the U.S. and other countries. Ubuntu(R) and the Ubuntu logo are registered trademarks of Canonical Ltd. Other product names used in this publication are for identification purposes only and may be trademarks of their respective companies. (C)2023 Advanced Micro Devices, Inc. All rights reserved.

----------------------
Third Party Disclaimer
----------------------
Third-party content is licensed to you directly by the third party that owns the content and is not licensed to you by AMD. ALL LINKED THIRD-PARTY CONTENT IS PROVIDED "AS IS" WITHOUT A WARRANTY OF ANY KIND. USE OF SUCH THIRD-PARTY CONTENT IS DONE AT YOUR SOLE DISCRETION AND UNDER NO CIRCUMSTANCES WILL AMD BE LIABLE TO YOU FOR ANY THIRD-PARTY CONTENT. YOU ASSUME ALL RISK AND ARE SOLELY RESPONSIBLE FOR ANY DAMAGES THAT MAY ARISE FROM YOUR USE OF THIRD-PARTY CONTENT.
15
docs/source/Linux_Install_Guide.rst
Normal file
@@ -0,0 +1,15 @@
=====================
Getting Started Guide
=====================

------------
Introduction
------------

This document contains instructions for installing, using, and contributing to Composable Kernel (CK).

Documentation Roadmap
^^^^^^^^^^^^^^^^^^^^^
The following is a list of CK documents in the suggested reading order:

[TODO]
20
docs/source/Makefile
Normal file
@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS  =
SPHINXBUILD = sphinx-build
SPHINXPROJ  = CK
SOURCEDIR   = .
BUILDDIR    = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
75
docs/source/Supported_Primitives_Guide.rst
Normal file
@@ -0,0 +1,75 @@
==========================
Supported Primitives Guide
==========================

This document contains details of supported primitives in Composable Kernel (CK). In contrast to the API Reference
Guide, the Supported Primitives Guide is an introduction to the math which underpins the algorithms implemented in CK.

------------
Softmax
------------

For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` we can decompose the softmax of the concatenation
:math:`x = [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ]` as,

.. math::
   :nowrap:

   \begin{align}
   m(x) & = m( [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ] ) = \max( m(x^{(1)}),\ldots, m(x^{(T)}) ) \\
   f(x) & = [\exp( m(x^{(1)}) - m(x) ) f( x^{(1)} )\ | \ \ldots \ | \ \exp( m(x^{(T)}) - m(x) ) f( x^{(T)} )] \\
   z(x) & = \exp( m(x^{(1)}) - m(x) )\ z(x^{(1)}) + \ldots + \exp( m(x^{(T)}) - m(x) )\ z(x^{(T)}) \\
   \operatorname{softmax}(x) &= f(x)\ / \ z(x)
   \end{align}

where :math:`f(x^{(j)}) = \exp( x^{(j)} - m(x^{(j)}) )` is of size :math:`B` and
:math:`z(x^{(j)}) = f(x_1^{(j)})+ \ldots+ f(x_B^{(j)})` is a scalar.

For a matrix :math:`X` composed of :math:`T_r \times T_c` tiles, :math:`X_{ij}`, of size :math:`B_r \times B_c` we can
compute the row-wise softmax as follows.

For :math:`j` from :math:`1` to :math:`T_c`, and :math:`i` from :math:`1` to :math:`T_r` calculate,

.. math::
   :nowrap:

   \begin{align}
   \tilde{m}_{ij} &= \operatorname{rowmax}( X_{ij} ) \\
   \tilde{P}_{ij} &= \exp(X_{ij} - \tilde{m}_{ij} ) \\
   \tilde{z}_{ij} &= \operatorname{rowsum}( \tilde{P}_{ij} ) \\
   \end{align}

If :math:`j=1`, initialize the running max, the running sum, and the first column block of the output,

.. math::
   :nowrap:

   \begin{align}
   m_i &= \tilde{m}_{i1} \\
   z_i &= \tilde{z}_{i1} \\
   \tilde{Y}_{i1} &= \diag(\tilde{z}_{i1})^{-1} \tilde{P}_{i1}
   \end{align}

Else if :math:`j>1`,

1. Update the running max, the running sum, and column blocks :math:`k=1` to :math:`k=j-1`, as shown in the sketch after this list.

.. math::
   :nowrap:

   \begin{align}
   m^{new}_i &= \max(m_i, \tilde{m}_{ij} ) \\
   z^{new}_i &= \exp(m_i - m^{new}_i)\ z_i + \exp( \tilde{m}_{ij} - m^{new}_i )\ \tilde{z}_{ij} \\
   Y_{ik} &= \diag(z^{new}_{i})^{-1} \diag(z_{i}) \exp(m_i - m^{new}_i)\ Y_{ik}
   \end{align}

2. Initialize column block :math:`j` of the output and reset the running max and running sum variables:

.. math::
   :nowrap:

   \begin{align}
   \tilde{Y}_{ij} &= \diag(z^{new}_{i})^{-1} \exp(\tilde{m}_{ij} - m^{new}_i ) \tilde{P}_{ij} \\
   z_i &= z^{new}_i \\
   m_i &= m^{new}_i \\
   \end{align}
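
As an illustration of the recurrence above, here is a standalone one-pass ("online") softmax over tiles of a single vector. This sketch is illustrative and not part of CK; the tile size ``B`` and all names are arbitrary::

    #include <algorithm>
    #include <cmath>
    #include <limits>
    #include <vector>

    // One-pass tiled softmax: keep a running max m and running sum z, rescale
    // already-written outputs whenever the running max grows, and normalize
    // each new tile with the updated sum.
    std::vector<float> tiled_softmax(const std::vector<float>& x, std::size_t B)
    {
        float m = -std::numeric_limits<float>::infinity();
        float z = 0.f;
        std::vector<float> y(x.size());
        for(std::size_t t = 0; t < x.size(); t += B)
        {
            const std::size_t end = std::min(x.size(), t + B);
            const float m_tile    = *std::max_element(x.begin() + t, x.begin() + end);
            const float m_new     = std::max(m, m_tile);
            float z_tile          = 0.f;
            for(std::size_t i = t; i < end; ++i)
                z_tile += std::exp(x[i] - m_tile);
            const float z_new = std::exp(m - m_new) * z + std::exp(m_tile - m_new) * z_tile;
            // rescale previously written outputs: Y *= diag(z_new)^-1 diag(z) exp(m - m_new)
            const float rescale = z / z_new * std::exp(m - m_new);
            for(std::size_t i = 0; i < t; ++i)
                y[i] *= rescale;
            // write the new tile: exp(x - m_new) / z_new
            for(std::size_t i = t; i < end; ++i)
                y[i] = std::exp(x[i] - m_new) / z_new;
            m = m_new;
            z = z_new;
        }
        return y;
    }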
216
docs/source/conf.py
Normal file
@@ -0,0 +1,216 @@
"""Copyright (C) 2018-2023 Advanced Micro Devices, Inc. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
ies of the Software, and to permit persons to whom the Software is furnished
to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

# -*- coding: utf-8 -*-
#
# Composable Kernel (CK) documentation build configuration file, based on
# rocBLAS documentation build configuration file, created by
# sphinx-quickstart on Mon Jan 8 16:34:42 2018.
#
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))

import os
import sys
import subprocess

read_the_docs_build = os.environ.get('READTHEDOCS', None) == 'True'

if read_the_docs_build:
    subprocess.call('../run_doxygen.sh')

# -- General configuration ------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['sphinx.ext.mathjax', 'breathe']
breathe_projects = { "CK": "../docBin/xml" }
breathe_default_project = "CK"

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'

# The master toctree document.
master_doc = 'index'

# General information about the project.
project = u'Composable Kernel (CK)'
copyright = u'2018-2023, Advanced Micro Devices'
author = u'Advanced Micro Devices'

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
#version = u'0.8'
# The full version, including alpha/beta/rc tags.
#release = u'0.8'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = 'en'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False


# -- Options for HTML output ----------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
# html_theme = 'alabaster'

#if read_the_docs_build:
#    html_theme = 'default'
#else:
import sphinx_rtd_theme
html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
html_logo = "rocm_logo.png"

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
html_theme_options = {
    'logo_only': True,
    'display_version': True
}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
#html_static_path = ['_static']

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# This is required for the alabaster theme
# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
# html_sidebars = {
#     '**': [
#         'relations.html',  # needs 'show_related': True theme option to display
#         'searchbox.html',
#     ]
# }

mathjax3_config = {
    'tex': {
        'macros': {
            'diag': '\\operatorname{diag}',
        }
    }
}

# -- Options for HTMLHelp output ------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'CKdoc'


# -- Options for LaTeX output ---------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    'papersize': 'letterpaper',

    # The font size ('10pt', '11pt' or '12pt').
    #
    'pointsize': '10pt',

    # Additional stuff for the LaTeX preamble.
    #
    'preamble': r'''
\setcounter{tocdepth}{5}
\newcommand{\diag}{\operatorname{diag}}
''',

    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, 'CK.tex', u'Composable Kernel (CK) Documentation',
     u'Advanced Micro Devices', 'manual'),
]


# -- Options for manual page output ---------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
    (master_doc, 'ck', u'Composable Kernel (CK) Documentation',
     [author], 1)
]


# -- Options for Texinfo output -------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (master_doc, 'CK', u'Composable Kernel (CK) Documentation',
     author, 'CK', 'Composable Kernel for AMD ROCm',
     'Miscellaneous'),
]
96
docs/source/dockerhub.rst
Normal file
@@ -0,0 +1,96 @@
===================
CK docker hub
===================

`Docker hub <https://hub.docker.com/r/rocm/composable_kernel>`_

-------------------------------------
Why do I need this?
-------------------------------------

To make our lives easier and bring the Composable Kernel dependencies together, we recommend using docker images.

-------------------------------------
So what is Composable Kernel?
-------------------------------------

The Composable Kernel (CK) library aims to provide a programming model for writing performance-critical kernels for machine learning workloads across multiple architectures, including GPUs, CPUs, etc., through general-purpose kernel languages like HIP C++.

To get the CK library::

    git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git

run a docker container::

    docker run \
    -it \
    --privileged \
    --group-add sudo \
    -w /root/workspace \
    -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
    rocm/composable_kernel:ck_ub20.04_rocm5.3_release \
    /bin/bash

and build CK::

    mkdir build && cd build
    # Need to specify target ID, example below is for gfx908 and gfx90a
    cmake \
    -D CMAKE_PREFIX_PATH=/opt/rocm \
    -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
    -D CMAKE_CXX_FLAGS="-O3" \
    -D CMAKE_BUILD_TYPE=Release \
    -D GPU_TARGETS="gfx908;gfx90a" \
    ..

and::

    make -j examples tests

To run all the test cases, including tests and examples, run::

    make test

We can also run specific examples or tests like::

    ./bin/example_gemm_xdl_fp16
    ./bin/test_gemm_fp16

For more details visit the `CK github repo <https://github.com/ROCmSoftwarePlatform/composable_kernel>`_, `CK examples <https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/example>`_, and `even more CK examples <https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/client_example>`_.

-------------------------------------
And what is inside?
-------------------------------------

The docker images have everything you need for running CK, including:

* `ROCm <https://www.amd.com/en/graphics/servers-solutions-rocm>`_
* `CMake <https://cmake.org/>`_
* `Compiler <https://github.com/RadeonOpenCompute/llvm-project>`_

-------------------------------------
Which image is right for me?
-------------------------------------

Let's take a look at the image naming, for example "ck_ub20.04_rocm5.4_release". The image specs are:

* "ck" - made for running Composable Kernel
* "ub20.04" - based on Ubuntu 20.04
* "rocm5.4" - ROCm platform version 5.4
* "release" - compiler version is release

So just pick the right image for your project dependencies and you're all set.

-------------------------------------
DIY starts here
-------------------------------------

If you need to customize a docker image or just can't stop tinkering, feel free to adjust the `Dockerfile <https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/Dockerfile>`_ for your needs.

-------------------------------------
License
-------------------------------------

CK is released under the MIT `license <https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/LICENSE>`_.
16
docs/source/index.rst
Normal file
@@ -0,0 +1,16 @@
============================
Composable Kernel User Guide
============================

.. toctree::
   :maxdepth: 5
   :caption: Contents:
   :numbered:

   Linux_Install_Guide
   tutorial_hello_world
   dockerhub
   Supported_Primitives_Guide
   API_Reference_Guide
   Contributors_Guide
   Disclaimer
BIN
docs/source/rocm_logo.png
Normal file
Binary file not shown. (size: 347 KiB)
174
docs/source/tutorial_hello_world.rst
Normal file
@@ -0,0 +1,174 @@
|
||||
===============
|
||||
CK Hello world
|
||||
===============
|
||||
|
||||
-------------------------------------
|
||||
Motivation
|
||||
-------------------------------------
|
||||
|
||||
This tutorial is aimed at engineers dealing with artificial intelligence and machine learning who would like to optimize their pipelines and squeeze every performance drop by adding Composable Kernel (CK) library to their projects. We would like to make the CK library approachable so the tutorial is not based on the latest release and doesn't have all the bleeding edge features, but it will be reproducible now and forever.
|
||||
|
||||
During this tutorial we will have an introduction to the CK library, we will build it and run some examples and tests, so to say we will run a "Hello world" example. In future tutorials we will go in depth and breadth and get familiar with other tools and ways to integrate CK into your project.
|
||||
|
||||
-------------------------------------
|
||||
Description
|
||||
-------------------------------------
|
||||
|
||||
Modern AI technology solves more and more problems in all imaginable fields, but crafting fast and efficient workflows is still challenging. CK is one of the tools to make AI heavy lifting as fast and efficient as possible. CK is a collection of optimized AI operator kernels and tools to create new ones. The library has components required for majority of modern neural networks architectures including matrix multiplication, convolution, contraction, reduction, attention modules, variety of activation functions, fused operators and many more.
|
||||
|
||||
So how do we (almost) reach the speed of light? CK acceleration abilities are based on:
|
||||
|
||||
* Layered structure.
|
||||
* Tile-based computation model.
|
||||
* Tensor coordinate transformation.
|
||||
* Hardware acceleration use.
|
||||
* Support of low precision data types including fp16, bf16, int8 and int4.
|
||||
|
||||
If you are excited and need more technical details and benchmarking results - read this awesome `blog post <https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_.
|
||||
|
||||
For more details visit our `github repo <https://github.com/ROCmSoftwarePlatform/composable_kernel>`_.
|
||||
|
||||
-------------------------------------
Hardware targets
-------------------------------------

The CK library fully supports the "gfx908" and "gfx90a" GPU architectures, and only some operators are supported on "gfx1030". Check the hardware you have at hand and decide on the target GPU architecture.

========== ================================================================================================================
GPU Target AMD GPU
========== ================================================================================================================
gfx908     Radeon Instinct MI100
gfx90a     Radeon Instinct MI210, MI250, MI250X
gfx1030    Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6800 XT, RX 6900 XT, RX 6900 XTX, RX 6950 XT
========== ================================================================================================================

There are also `cloud options <https://aws.amazon.com/ec2/instance-types/g4/>`_ you can explore if you don't have an AMD GPU at hand.
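
If you are not sure which architecture your machine has, one quick way to check (assuming the ROCm stack is already installed on the host) is to grep the ``rocminfo`` output for the "gfx" name::

    rocminfo | grep gfx
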
-------------------------------------
Build the library
-------------------------------------

First, let's clone the library and check out the tested version::

    git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git
    cd composable_kernel/
    git checkout tutorial_hello_world

To make our lives easier we have prepared `docker images <https://hub.docker.com/r/rocm/composable_kernel>`_ with all the necessary dependencies. Pick the right image and create a container. In this tutorial we use the "rocm/composable_kernel:ck_ub20.04_rocm5.3_release" image, which is based on Ubuntu 20.04, ROCm v5.3 and the release version of the compiler.

If your current folder is ${HOME}, start the docker container with::

    docker run \
    -it \
    --privileged \
    --group-add sudo \
    -w /root/workspace \
    -v ${HOME}:/root/workspace \
    rocm/composable_kernel:ck_ub20.04_rocm5.3_release \
    /bin/bash

If your current folder is different from ${HOME}, adjust the line ``-v ${HOME}:/root/workspace`` to fit your folder structure, as in the sketch below.
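
For example, if you cloned the library into a hypothetical /data/projects folder (the path is illustrative only), the volume mapping would look like this::

    docker run -it --privileged --group-add sudo \
        -w /root/workspace -v /data/projects:/root/workspace \
        rocm/composable_kernel:ck_ub20.04_rocm5.3_release /bin/bash
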
Inside the docker container the current folder is "~/workspace" and the library path is "~/workspace/composable_kernel"; navigate to the library::

    cd composable_kernel/

Create and go to the "build" directory::

    mkdir build && cd build

In the previous section we talked about target GPU architectures. Once you decide which one is right for you, run cmake with the right GPU_TARGETS flag::

    cmake \
    -D CMAKE_PREFIX_PATH=/opt/rocm \
    -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
    -D CMAKE_CXX_FLAGS="-O3" \
    -D CMAKE_BUILD_TYPE=Release \
    -D BUILD_DEV=OFF \
    -D GPU_TARGETS="gfx908;gfx90a;gfx1030" ..

If everything went well, the cmake run will end with::

    -- Configuring done
    -- Generating done
    -- Build files have been written to: "/root/workspace/composable_kernel/build"

Finally, we can build the examples and tests::

    make -j examples tests

If everything is smooth, you'll see::

    Scanning dependencies of target tests
    [100%] Built target tests

---------------------------
Run examples and tests
---------------------------

Examples are registered as test cases as well, so we can run all examples and tests with::

    ctest

You can see the list of all tests by running::

    ctest -N

We can also run them individually; here is an example execution::

    ./bin/example_gemm_xdl_fp16 1 1 1

The arguments "1 1 1" mean that we want to run this example in the mode: verify results on the CPU, initialize matrices with integers, and benchmark the kernel execution. You can play around with these parameters (see the sketch below) and see how the output and execution results change.
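
For instance, using the same binary, the following invocations skip CPU verification, use decimal initialization, or skip kernel timing, respectively::

    ./bin/example_gemm_xdl_fp16 0 1 1
    ./bin/example_gemm_xdl_fp16 1 2 1
    ./bin/example_gemm_xdl_fp16 1 1 0
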

If everything goes well and you have a device based on the gfx908 or gfx90a architecture, you should see something like::

    a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
    b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
    c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
    launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
    Warm up 1 time
    Start running 10 times...
    Perf: 1.10017 ms, 117.117 TFlops, 87.6854 GB/s, DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1

Meanwhile, running it on a gfx1030 device should result in::

    a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
    b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
    c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
    DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem

But don't panic: some of the operators are supported on the gfx1030 architecture, so you can run an example like::

    ./bin/example_gemm_dl_fp16 1 1 1

and it should produce output similar to::

    a_m_k: dim 2, lengths {3840, 4096}, strides {1, 4096}
    b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
    c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
    arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2}
    arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2}
    arg.c_grid_desc_m_n_{ 3840, 4096}
    launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1}
    Warm up 1 time
    Start running 10 times...
    Perf: 3.65695 ms, 35.234 TFlops, 26.3797 GB/s, DeviceGemmDl<256, 128, 128, 16, 2, 4, 4, 1>

Or we can run an individual test::

    ctest -R test_gemm_fp16

If everything goes well, you should see something like::

    Start 121: test_gemm_fp16
    1/1 Test #121: test_gemm_fp16 ................... Passed 51.81 sec

    100% tests passed, 0 tests failed out of 1

-----------
Summary
-----------

In this tutorial we took a first look at the Composable Kernel library, built it, and ran some examples and tests. Stay tuned: in the next tutorial we will run kernels with different configs to find the best one for your hardware and task.

P.S.: Don't forget to shut down the cloud instance if you launched one; you can surely find better ways to spend your money!
@@ -38,7 +38,9 @@ add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)

add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
if(GPU_TARGETS MATCHES "gfx1100")
    add_custom_target(example_gemm_wmma)
    add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
    add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
endif()

@@ -1 +1,4 @@
add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp)
if(GPU_TARGETS MATCHES "gfx1100")
    add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
endif()
304 example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp Normal file
@@ -0,0 +1,304 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"

// Elementwise epilogue for the bilinear GEMM: e = alpha * c + beta * d
struct AlphaBetaAdd
{
    AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {}

    template <typename E, typename C, typename D>
    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;

    template <>
    __host__ __device__ constexpr void operator()<ck::half_t, float, ck::half_t>(
        ck::half_t& e, const float& c, const ck::half_t& d) const
    {
        e = ck::type_convert<ck::half_t>(alpha_ * c + beta_ * ck::type_convert<float>(d));
    }

    float alpha_;
    float beta_;
};

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType        = F16;
using BDataType        = F16;
using AccDataType      = F32;
using CShuffleDataType = F32;
using DDataType        = F16;
using EDataType        = F16;

using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using ELayout = Row;

using AElementOp   = PassThrough;
using BElementOp   = PassThrough;
using CDEElementOp = AlphaBetaAdd;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<
    ALayout, BLayout, ck::Tuple<DLayout>, ELayout,
    ADataType, BDataType, ck::Tuple<DDataType>, EDataType, AccDataType, CShuffleDataType,
    AElementOp, BElementOp, CDEElementOp, GemmSpec,
    256, 128, 256, 8, 8, 16, 16, 4, 4,
    S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
    S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
    1, 1, S<1, 32, 1, 8>, 8>;

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = true;

    // GEMM shape
    ck::index_t M = 3840;
    ck::index_t N = 4096;
    ck::index_t K = 4096;

    ck::index_t StrideA = 4096;
    ck::index_t StrideB = 4096;
    ck::index_t StrideD = 4096;
    ck::index_t StrideE = 4096;

    float alpha = 1.0f;
    float beta  = 1.0f;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 6)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        alpha = std::stof(argv[4]);
        beta  = std::stof(argv[5]);
    }
    else if(argc == 13)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);

        StrideA = std::stoi(argv[7]);
        StrideB = std::stoi(argv[8]);
        StrideD = std::stoi(argv[9]);
        StrideE = std::stoi(argv[10]);

        alpha = std::stof(argv[11]);
        beta  = std::stof(argv[12]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 12: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, "
               "beta\n");
        exit(0);
    }

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            using namespace ck::literals;

            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
    Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
    Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
    std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
    std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
        break;
    default:
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
    }

    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());
    d_device_buf.ToDevice(d_m_n.mData.data());
    e_device_buf.ToDevice(e_m_n_device_result.mData.data());

    auto a_element_op   = AElementOp{};
    auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{alpha, beta};

    // do GEMM
    auto device_op = DeviceOpInstance{};
    auto invoker   = device_op.MakeInvoker();
    auto argument =
        device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
                               b_device_buf.GetDeviceBuffer(),
                               std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
                               e_device_buf.GetDeviceBuffer(),
                               M,
                               N,
                               K,
                               StrideA,
                               StrideB,
                               std::array<ck::index_t, 1>{StrideD},
                               StrideE,
                               a_element_op,
                               b_element_op,
                               cde_element_op);

    if(!device_op.IsSupportedArgument(argument))
    {
        throw std::runtime_error(
            "wrong! device_gemm with the specified compilation parameters does "
            "not support this GEMM problem");
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_btype =
        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
              << std::endl;

    e_device_buf.FromDevice(e_m_n_device_result.mData.data());

    if(do_verification)
    {
        Tensor<CShuffleDataType> c_m_n({M, N});

        using ReferenceGemmInstance =
            ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                      BDataType,
                                                      CShuffleDataType,
                                                      AccDataType,
                                                      AElementOp,
                                                      BElementOp,
                                                      PassThrough>;
        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument =
            ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});

        ref_invoker.Run(ref_argument);

        for(int m = 0; m < M; ++m)
        {
            for(int n = 0; n < N; ++n)
            {
                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
            }
        }

        e_device_buf.FromDevice(e_m_n_device_result.mData.data());

        return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
    }

    return 0;
}
@@ -6,3 +6,9 @@ add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16
                 example_grouped_conv_bwd_weight_xdl_bf16)

add_custom_target(example_grouped_conv_bwd_weight_dl)
add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
@@ -9,7 +9,6 @@

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

@@ -0,0 +1,59 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp"

using InDataType  = F16;
using WeiDataType = F16;
using OutDataType = F16;
using AccDataType = F32;

using InElementOp  = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = PassThrough;

template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
    ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl<
        NDimSpatial,          // NDimSpatial
        InDataType,           // InDataType
        WeiDataType,          // WeiDataType
        OutDataType,          // OutDataType
        AccDataType,          // AccDataType
        InElementOp,          // InElementwiseOperation
        WeiElementOp,         // WeiElementwiseOperation
        OutElementOp,         // OutElementwiseOperation
        ConvBwdWeightDefault, // ConvBackwardWeightSpecialization
        256,                  // BlockSize
        128,                  // MPerBlock
        128,                  // NPerBlock
        16,                   // K0PerBlock
        2,                    // K1
        4,                    // M1PerThread
        4,                    // N1PerThread
        1,                    // KPerThread
        S<8, 2>,              // M1N1ThreadClusterM1Xs
        S<8, 2>,              // M1N1ThreadClusterN1Xs
        S<1, 8, 1, 1, 2>,     // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
        S<1, 2, 1, 128, 1>,   // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
        S<0, 2, 3, 1, 4>,     // ABlockTransferThreadClusterArrangeOrder
        S<0, 2, 3, 1, 4>,     // ABlockTransferSrcAccessOrder
        S<1, 1, 1, 1, 1>,     // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
        S<0, 2, 3, 1, 4>,     // ABlockTransferSrcVectorTensorContiguousDimOrder
        S<1, 1, 1, 1, 1>,     // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
        S<1, 1, 1, 8, 2>,     // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
        S<1, 16, 1, 16, 1>,   // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
        S<0, 1, 4, 2, 3>,     // BBlockTransferThreadClusterArrangeOrder
        S<0, 1, 4, 2, 3>,     // BBlockTransferSrcAccessOrder
        S<1, 1, 1, 8, 1>,     // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
        S<0, 1, 4, 2, 3>,     // BBlockTransferSrcVectorTensorContiguousDimOrder
        S<1, 1, 1, 1, 2>,     // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
        S<0, 1, 2, 3, 4, 5>,  // CThreadTransferSrcDstAccessOrder
        5,                    // CThreadTransferSrcDstVectorDim
        4>;                   // CThreadTransferDstScalarPerVector

#include "run_grouped_conv_bwd_weight_example.inc"

int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
@@ -3,6 +3,8 @@

#include "common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"

using InDataType = BF16;
// The bf16 kernel uses fp32 atomic add to accumulate the weight tensor in global memory
using WeiDataType = F32;
@@ -13,6 +15,46 @@ using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = PassThrough;

template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
    ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
        NDimSpatial,          // NDimSpatial
        InDataType,           // InDataType
        WeiDataType,          // WeiDataType
        OutDataType,          // OutDataType
        AccDataType,          // AccDataType
        InElementOp,          // InElementwiseOperation
        WeiElementOp,         // WeiElementwiseOperation
        OutElementOp,         // OutElementwiseOperation
        ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
        256,                  // BlockSize
        128,                  // MPerBlock
        128,                  // NPerBlock
        4,                    // K0PerBlock
        8,                    // K1
        32,                   // MPerXdl
        32,                   // NPerXdl
        2,                    // MXdlPerWave
        2,                    // NXdlPerWave
        S<1, 4, 16, 4>,       // ABlockTransferThreadClusterLengths_K0_M_K1
        S<0, 3, 1, 2>,        // ABlockTransferThreadClusterArrangeOrder
        S<0, 2, 1, 3>,        // ABlockTransferSrcAccessOrder
        2,                    // ABlockTransferSrcVectorDim
        8,                    // ABlockTransferSrcScalarPerVector
        2,                    // ABlockTransferDstScalarPerVector_K1
        true,                 // ABlockLdsAddExtraM
        S<1, 4, 16, 4>,       // BBlockTransferThreadClusterLengths_K0_N_K1
        S<0, 3, 1, 2>,        // BBlockTransferThreadClusterArrangeOrder
        S<0, 2, 1, 3>,        // BBlockTransferSrcAccessOrder
        2,                    // BBlockTransferSrcVectorDim
        8,                    // BBlockTransferSrcScalarPerVector
        2,                    // BBlockTransferDstScalarPerVector_K1
        true,                 // BBlockLdsAddExtraN
        1,                    // CShuffleMXdlPerWavePerShuffle
        1,                    // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 4>,       // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl

#include "run_grouped_conv_bwd_weight_example.inc"

int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }

@@ -3,6 +3,8 @@

#include "common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"

using InDataType  = F16;
using WeiDataType = F16;
using OutDataType = F16;
@@ -12,6 +14,46 @@ using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = PassThrough;

template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
    ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
        NDimSpatial,          // NDimSpatial
        InDataType,           // InDataType
        WeiDataType,          // WeiDataType
        OutDataType,          // OutDataType
        AccDataType,          // AccDataType
        InElementOp,          // InElementwiseOperation
        WeiElementOp,         // WeiElementwiseOperation
        OutElementOp,         // OutElementwiseOperation
        ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
        256,                  // BlockSize
        128,                  // MPerBlock
        128,                  // NPerBlock
        4,                    // K0PerBlock
        8,                    // K1
        32,                   // MPerXdl
        32,                   // NPerXdl
        2,                    // MXdlPerWave
        2,                    // NXdlPerWave
        S<1, 4, 16, 4>,       // ABlockTransferThreadClusterLengths_K0_M_K1
        S<0, 3, 1, 2>,        // ABlockTransferThreadClusterArrangeOrder
        S<0, 2, 1, 3>,        // ABlockTransferSrcAccessOrder
        2,                    // ABlockTransferSrcVectorDim
        8,                    // ABlockTransferSrcScalarPerVector
        2,                    // ABlockTransferDstScalarPerVector_K1
        true,                 // ABlockLdsAddExtraM
        S<1, 4, 16, 4>,       // BBlockTransferThreadClusterLengths_K0_N_K1
        S<0, 3, 1, 2>,        // BBlockTransferThreadClusterArrangeOrder
        S<0, 2, 1, 3>,        // BBlockTransferSrcAccessOrder
        2,                    // BBlockTransferSrcVectorDim
        8,                    // BBlockTransferSrcScalarPerVector
        2,                    // BBlockTransferDstScalarPerVector_K1
        true,                 // BBlockLdsAddExtraN
        1,                    // CShuffleMXdlPerWavePerShuffle
        1,                    // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 4>,       // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl

#include "run_grouped_conv_bwd_weight_example.inc"

int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }

@@ -1,46 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance =
    ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
        NDimSpatial,          // NDimSpatial
        InDataType,           // InDataType
        WeiDataType,          // WeiDataType
        OutDataType,          // OutDataType
        AccDataType,          // AccDataType
        InElementOp,          // InElementwiseOperation
        WeiElementOp,         // WeiElementwiseOperation
        OutElementOp,         // OutElementwiseOperation
        ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
        256,                  // BlockSize
        128,                  // MPerBlock
        128,                  // NPerBlock
        4,                    // K0PerBlock
        8,                    // K1
        32,                   // MPerXdl
        32,                   // NPerXdl
        2,                    // MXdlPerWave
        2,                    // NXdlPerWave
        S<1, 4, 16, 4>,       // ABlockTransferThreadClusterLengths_K0_M_K1
        S<0, 3, 1, 2>,        // ABlockTransferThreadClusterArrangeOrder
        S<0, 2, 1, 3>,        // ABlockTransferSrcAccessOrder
        2,                    // ABlockTransferSrcVectorDim
        8,                    // ABlockTransferSrcScalarPerVector
        2,                    // ABlockTransferDstScalarPerVector_K1
        true,                 // ABlockLdsAddExtraM
        S<1, 4, 16, 4>,       // BBlockTransferThreadClusterLengths_K0_N_K1
        S<0, 3, 1, 2>,        // BBlockTransferThreadClusterArrangeOrder
        S<0, 2, 1, 3>,        // BBlockTransferSrcAccessOrder
        2,                    // BBlockTransferSrcVectorDim
        8,                    // BBlockTransferSrcScalarPerVector
        2,                    // BBlockTransferDstScalarPerVector_K1
        true,                 // BBlockLdsAddExtraN
        1,                    // CShuffleMXdlPerWavePerShuffle
        1,                    // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 4>,       // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl

template <ck::index_t NDimSpatial>
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
                                                                                     InDataType,
@@ -54,8 +14,19 @@ template <ck::index_t NDimSpatial>
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
                                 const ck::utils::conv::ConvParam& conv_param)
{
    constexpr ck::index_t split_k = 2;

    ck::index_t split_k;
    // Set split_k = 2 for the xdl op and split_k = 1 for dl;
    // the Dl op doesn't support split_k > 1.
    // TODO: Add Dl op split_k > 1 support
    if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
    {
        split_k = 2;
    }
    else
    {
        split_k = 1;
    }

    const auto in_g_n_c_wis_desc =
        ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
            InputLayout<NDimSpatial>>(conv_param);
@@ -144,7 +115,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
        std::cerr << "wrong! device_conv with the specified compilation parameters does "
                     "not support this Conv problem"
                  << std::endl;
        return false;
        return true;
    }

    float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});

@@ -4,7 +4,6 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
@@ -116,7 +115,7 @@ auto f_host_tensor_descriptor2d =
    [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
        using namespace ck::literals;

        if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
        if constexpr(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
        {
            return HostTensorDescriptor({row, col}, {stride, 1_uz});
        }

@@ -4,7 +4,6 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
@@ -15,6 +14,7 @@
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
#include "ck/library/utility/check_err.hpp"
@@ -69,21 +69,20 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern
// clang-format on

auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
    return HostTensorDescriptor(std::vector<std::size_t>({len}),
                                std::vector<std::size_t>({stride}));
    return HostTensorDescriptor({len}, {stride});
};

auto f_host_tensor_descriptor2d =
    [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
        if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
        using namespace ck::literals;

        if constexpr(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                        std::vector<std::size_t>({stride, 1}));
            return HostTensorDescriptor({row, col}, {stride, 1_uz});
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                        std::vector<std::size_t>({1, stride}));
            return HostTensorDescriptor({row, col}, {1_uz, stride});
        }
    };

@@ -97,6 +96,7 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
                         AElementOp a_element_op,
                         BElementOp b_element_op,
                         CDEElementOp cde_element_op,
                         HElementOp h_element_op,
                         int M,
                         int N,
                         AccDataType epsilon = 1e-5)
@@ -145,7 +145,7 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
    auto ref_layernorm_invoker = ref_layernorm.MakeInvoker();

    auto ref_layernorm_argument = ref_layernorm.MakeArgument(
        e_m_n, gamma_n, beta_n, h_m_n, HElementOp{}, {M, N}, {1}, epsilon);
        e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon);
    ref_layernorm_invoker.Run(ref_layernorm_argument);
}

@@ -249,6 +249,7 @@ int main()
                         a_element_op,
                         b_element_op,
                         cde_element_op,
                         h_element_op,
                         M,
                         N,
                         epsilon);

@@ -4,7 +4,6 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
@@ -115,7 +114,7 @@ auto f_host_tensor_descriptor2d =
    [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
        using namespace ck::literals;

        if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
        if constexpr(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
        {
            return HostTensorDescriptor({row, col}, {stride, 1_uz});
        }
@@ -135,7 +135,7 @@ int main(int argc, char* argv[])
    [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
        using namespace ck::literals;

        if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
        if constexpr(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
        {
            return HostTensorDescriptor({row, col}, {stride, 1_uz});
        }

@@ -1,2 +1,5 @@
add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp)

add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp)
add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp)
427 example/26_contraction/contraction_bilinear_xdl_fp64.cpp Normal file
@@ -0,0 +1,427 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F64 = double;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType        = F64;
using BDataType        = F64;
using AccDataType      = F64;
using CShuffleDataType = F64;
using DDataType        = F64;
using DsDataType       = ck::Tuple<DDataType>;
using EDataType        = F64;

static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;

using AElementOp   = ck::tensor_operation::element_wise::PassThrough;
using BElementOp   = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// clang-format off
// The columns of the single-line instances below are, in order:
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
using DeviceOpInstanceKKNN = ck::tensor_operation::device::
    DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>;

using DeviceOpInstanceKNNN = ck::tensor_operation::device::
    DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>;

using DeviceOpInstanceMKNN = ck::tensor_operation::device::
    DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>;

using DeviceOpInstanceMNNN = ck::tensor_operation::device::
    DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>;
// clang-format on

using DeviceOpInstance = DeviceOpInstanceKKNN;

// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
          ck::index_t NumDimN,
          ck::index_t NumDimK,
          typename ADataType,
          typename BDataType,
          typename EDataType,
          typename AccDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
    // Argument
    struct Argument : public ck::tensor_operation::device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_ms_ks,
                 const Tensor<BDataType>& b_ns_ks,
                 Tensor<EDataType>& e_ms_ns,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : a_ms_ks_{a_ms_ks},
              b_ns_ks_{b_ns_ks},
              e_ms_ns_{e_ms_ns},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
        }

        const Tensor<ADataType>& a_ms_ks_;
        const Tensor<BDataType>& b_ns_ks_;
        Tensor<EDataType>& e_ms_ns_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };

    // Invoker
    struct Invoker : public ck::tensor_operation::device::BaseInvoker
    {
        using Argument = ReferenceContraction_M2_N2_K2::Argument;

        float Run(const Argument& arg)
        {
            auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
                const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
                const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];

                AccDataType v_acc = 0;

                for(int k0 = 0; k0 < K0; ++k0)
                {
                    for(int k1 = 0; k1 < K1; ++k1)
                    {
                        AccDataType v_a;
                        AccDataType v_b;

                        arg.a_element_op_(
                            v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
                        arg.b_element_op_(
                            v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));

                        v_acc += v_a * v_b;
                    }
                }

                AccDataType v_c;

                arg.cde_element_op_(v_c, v_acc);

                arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
            };

            make_ParallelTensorFunctor(f_ms_ns,
                                       arg.e_ms_ns_.mDesc.GetLengths()[0],
                                       arg.e_ms_ns_.mDesc.GetLengths()[1],
                                       arg.e_ms_ns_.mDesc.GetLengths()[2],
                                       arg.e_ms_ns_.mDesc.GetLengths()[3])(
                std::thread::hardware_concurrency());

            return 0;
        }

        float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
    {
        return true;
    }

    static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
                             const Tensor<BDataType>& b_ns_ks,
                             Tensor<EDataType>& e_ms_ns,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceContraction_M2_N2_K2"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    // A[M0, M1, K0, K1]
    std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
    std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
    // B[N0, N1, K0, K1]
    std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
    std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
    // D[M0, M1, N0, N1]
    std::vector<ck::index_t> d_ms_ns_lengths{30, 128, 32, 64};
    std::vector<ck::index_t> d_ms_ns_strides{524288, 4096, 128, 1};
    // E[M0, M1, N0, N1]
    std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
    std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};

    float alpha = 1.f;
    float beta  = 1.f;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 28)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        const ck::index_t M0 = std::stoi(argv[4]);
        const ck::index_t M1 = std::stoi(argv[5]);

        const ck::index_t N0 = std::stoi(argv[6]);
        const ck::index_t N1 = std::stoi(argv[7]);

        const ck::index_t K0 = std::stoi(argv[8]);
        const ck::index_t K1 = std::stoi(argv[9]);

        a_ms_ks_lengths = {M0, M1, K0, K1};
        a_ms_ks_strides = {
            std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])};

        b_ns_ks_lengths = {N0, N1, K0, K1};
        b_ns_ks_strides = {
            std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])};

        d_ms_ns_lengths = {M0, M1, N0, N1};
        d_ms_ns_strides = {
            std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])};

        e_ms_ns_lengths = {M0, M1, N0, N1};
        e_ms_ns_strides = {
            std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])};

        alpha = std::stof(argv[26]);
        beta  = std::stof(argv[27]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n");
        printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
        printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
        printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n");
        printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
        printf("arg26 to 27: alpha, beta\n");
        exit(0);
    }

    Tensor<ADataType> a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides);
    Tensor<BDataType> b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides);
    Tensor<EDataType> d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides);
    Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
    Tensor<EDataType> e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides);

    std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
    std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
    std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl;
    std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        d_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        break;
    default:
        a_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        d_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        break;
    }

    DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_ms_ks.mData.data());
    b_device_buf.ToDevice(b_ns_ks.mData.data());
    d_device_buf.ToDevice(d_ms_ns.mData.data());

    // set zero
    e_device_buf.SetZero();

    auto a_element_op   = AElementOp{};
    auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{alpha, beta};

    // device operation
    auto op       = DeviceOpInstance{};
    auto invoker  = op.MakeInvoker();
    auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                    b_device_buf.GetDeviceBuffer(),
                                    std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
                                    e_device_buf.GetDeviceBuffer(),
                                    a_ms_ks_lengths,
                                    a_ms_ks_strides,
                                    b_ns_ks_lengths,
                                    b_ns_ks_strides,
                                    std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
                                    std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
                                    e_ms_ns_lengths,
                                    e_ms_ns_strides,
                                    a_element_op,
                                    b_element_op,
                                    cde_element_op);

    if(!op.IsSupportedArgument(argument))
    {
        std::cout << op.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    ck::index_t M =
        ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});

    ck::index_t N = ck::accumulate_n<ck::index_t>(
        e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{});

    ck::index_t K = ck::accumulate_n<ck::index_t>(
        a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{});

    std::size_t flop      = std::size_t(2) * M * N * K;
    std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                            sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << op.GetTypeString() << std::endl;

    e_device_buf.FromDevice(e_ms_ns_device_result.mData.data());

    if(do_verification)
    {
        Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);

        using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
                                                                  NumDimN,
                                                                  NumDimK,
                                                                  ADataType,
                                                                  BDataType,
                                                                  CShuffleDataType,
                                                                  AccDataType,
                                                                  AElementOp,
                                                                  BElementOp,
                                                                  PassThrough>;

        auto ref_gemm    = ReferenceOpInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});

        ref_invoker.Run(ref_argument);

        for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
        {
            for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
            {
                for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
                {
                    for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
                    {
                        cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
                                       c_ms_ns_host_result(m0, m1, n0, n1),
                                       d_ms_ns(m0, m1, n0, n1));
                    }
                }
            }
        }

        return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
    }

    return 0;
}
409 example/26_contraction/contraction_scale_xdl_fp64.cpp Normal file
@@ -0,0 +1,409 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F64 = double;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType        = F64;
using BDataType        = F64;
using AccDataType      = F64;
using CShuffleDataType = F64;
using DsDataType       = ck::Tuple<>;
using EDataType        = F64;

static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;

using AElementOp   = ck::tensor_operation::element_wise::PassThrough;
using BElementOp   = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Scale;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// clang-format off
// The columns of the single-line instances below are, in order:
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
using DeviceOpInstanceKKN = ck::tensor_operation::device::
    DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>;

using DeviceOpInstanceKNN = ck::tensor_operation::device::
    DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>;
|
||||
|
||||
using DeviceOpInstanceMKN = ck::tensor_operation::device::
|
||||
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>;
|
||||
|
||||
using DeviceOpInstanceMNN = ck::tensor_operation::device::
|
||||
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>;
|
||||
// clang-format on
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKN;
|
||||
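// (editorial note) The four instances above appear to differ only in which input
// dimension is innermost and therefore vectorizable: KKN = A and B both
// K-contiguous, KNN = B N-contiguous, MKN = A M-contiguous, MNN = A M- and
// B N-contiguous. Pick the variant matching the actual tensor strides.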

// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
          ck::index_t NumDimN,
          ck::index_t NumDimK,
          typename ADataType,
          typename BDataType,
          typename EDataType,
          typename AccDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
{
    // Argument
    struct Argument : public ck::tensor_operation::device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_ms_ks,
                 const Tensor<BDataType>& b_ns_ks,
                 Tensor<EDataType>& e_ms_ns,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : a_ms_ks_{a_ms_ks},
              b_ns_ks_{b_ns_ks},
              e_ms_ns_{e_ms_ns},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
        }

        const Tensor<ADataType>& a_ms_ks_;
        const Tensor<BDataType>& b_ns_ks_;
        Tensor<EDataType>& e_ms_ns_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };

    // Invoker
    struct Invoker : public ck::tensor_operation::device::BaseInvoker
    {
        using Argument = ReferenceContraction_M2_N2_K2::Argument;

        float Run(const Argument& arg)
        {
            auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
                const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
                const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];

                AccDataType v_acc = 0;

                for(int k0 = 0; k0 < K0; ++k0)
                {
                    for(int k1 = 0; k1 < K1; ++k1)
                    {
                        AccDataType v_a;
                        AccDataType v_b;

                        arg.a_element_op_(
                            v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
                        arg.b_element_op_(
                            v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));

                        v_acc += v_a * v_b;
                    }
                }

                AccDataType v_c;

                arg.cde_element_op_(v_c, v_acc);

                arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
            };

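            // Evaluate f_ms_ns at every (m0, m1, n0, n1) of E in parallel on the host.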
            make_ParallelTensorFunctor(f_ms_ns,
                                       arg.e_ms_ns_.mDesc.GetLengths()[0],
                                       arg.e_ms_ns_.mDesc.GetLengths()[1],
                                       arg.e_ms_ns_.mDesc.GetLengths()[2],
                                       arg.e_ms_ns_.mDesc.GetLengths()[3])(
                std::thread::hardware_concurrency());

            return 0;
        }

        float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
    {
        return true;
    }

    static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
                             const Tensor<BDataType>& b_ns_ks,
                             Tensor<EDataType>& e_ms_ns,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceContraction_M2_N2_K2"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    // A[M0, M1, K0, K1]
    std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
    std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
    // B[N0, N1, K0, K1]
    std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
    std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
    // E[M0, M1, N0, N1]
    std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
    std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};

    float scale = 1.f;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 23)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        const ck::index_t M0 = std::stoi(argv[4]);
        const ck::index_t M1 = std::stoi(argv[5]);

        const ck::index_t N0 = std::stoi(argv[6]);
        const ck::index_t N1 = std::stoi(argv[7]);

        const ck::index_t K0 = std::stoi(argv[8]);
        const ck::index_t K1 = std::stoi(argv[9]);

        a_ms_ks_lengths = {M0, M1, K0, K1};
        a_ms_ks_strides = {
            std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])};

        b_ns_ks_lengths = {N0, N1, K0, K1};
        b_ns_ks_strides = {
            std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])};

        e_ms_ns_lengths = {M0, M1, N0, N1};
        e_ms_ns_strides = {
            std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])};

        scale = std::stof(argv[22]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n");
        printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
        printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
        printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
        printf("arg22: scale\n");
        exit(0);
    }

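    // A minimal usage sketch (the binary name is assumed from the usual
    // add_example_executable naming and is not part of this diff):
    //   ./example_contraction_scale_xdl_fp64 1 2 1   # verify, decimal init, time kernel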
    Tensor<ADataType> a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides);
    Tensor<BDataType> b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides);
    Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
    Tensor<EDataType> e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides);

    std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
    std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
    std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        break;
    default:
        a_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        break;
    }

    DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_ms_ks.mData.data());
    b_device_buf.ToDevice(b_ns_ks.mData.data());

    // set zero
    e_device_buf.SetZero();

    auto a_element_op   = AElementOp{};
    auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{scale};

    // device operation
    auto op       = DeviceOpInstance{};
    auto invoker  = op.MakeInvoker();
    auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                    b_device_buf.GetDeviceBuffer(),
                                    std::array<const void*, 0>{},
                                    e_device_buf.GetDeviceBuffer(),
                                    a_ms_ks_lengths,
                                    a_ms_ks_strides,
                                    b_ns_ks_lengths,
                                    b_ns_ks_strides,
                                    std::array<std::vector<ck::index_t>, 0>{},
                                    std::array<std::vector<ck::index_t>, 0>{},
                                    e_ms_ns_lengths,
                                    e_ms_ns_strides,
                                    a_element_op,
                                    b_element_op,
                                    cde_element_op);

    if(!op.IsSupportedArgument(argument))
    {
        std::cout << op.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    ck::index_t M =
        ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});

    ck::index_t N = ck::accumulate_n<ck::index_t>(
        e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{});

    ck::index_t K = ck::accumulate_n<ck::index_t>(
        a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{});

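    // (editorial note) the multi-index dimensions are flattened for the perf
    // model below: M = M0 * M1, N = N0 * N1, K = K0 * K1.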
    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_btype =
        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << op.GetTypeString() << std::endl;

    e_device_buf.FromDevice(e_ms_ns_device_result.mData.data());

    if(do_verification)
    {
        Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);

        using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
                                                                  NumDimN,
                                                                  NumDimK,
                                                                  ADataType,
                                                                  BDataType,
                                                                  CShuffleDataType,
                                                                  AccDataType,
                                                                  AElementOp,
                                                                  BElementOp,
                                                                  PassThrough>;

        auto ref_gemm    = ReferenceOpInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});

        ref_invoker.Run(ref_argument);

        for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
        {
            for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
            {
                for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
                {
                    for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
                    {
                        cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
                                       c_ms_ns_host_result(m0, m1, n0, n1));
                    }
                }
            }
        }

        return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
    }

    return 0;
}
@@ -20,12 +20,12 @@
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"

using XDataType     = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType  = ck::half_t;
using YDataType     = ck::half_t;
using AccDataType   = float;
using PassThrough   = ck::tensor_operation::element_wise::PassThrough;
using XDataType       = ck::half_t;
using GammaDataType   = ck::half_t;
using BetaDataType    = ck::half_t;
using YDataType       = ck::half_t;
using ComputeDataType = float;
using PassThrough     = ck::tensor_operation::element_wise::PassThrough;

constexpr int Rank         = 2;
constexpr int NumReduceDim = 1;
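(Editorial note: the hunks in this file appear to rename the example's accumulation
alias from AccDataType to ComputeDataType, matching the updated
DeviceNormalizationImpl parameter; each hunk lists the old line followed by the new one.)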
@@ -34,7 +34,7 @@ using DeviceInstance =
    ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
                                                          GammaDataType,
                                                          BetaDataType,
                                                          AccDataType,
                                                          ComputeDataType,
                                                          YDataType,
                                                          PassThrough,
                                                          Rank,
@@ -121,7 +121,7 @@ int main()
                                                          GammaDataType,
                                                          BetaDataType,
                                                          YDataType,
                                                          AccDataType,
                                                          ComputeDataType,
                                                          PassThrough,
                                                          Rank,
                                                          NumReduceDim>;

@@ -1 +1,5 @@
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)

if(GPU_TARGETS MATCHES "gfx1100")
    add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
endif()

@@ -0,0 +1,431 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add         = ck::tensor_operation::element_wise::Add;

using ADataType        = F16;
using BDataType        = F16;
using AccDataType      = F32;
using CShuffleDataType = F16;
using DDataType        = F16;
using DsDataType       = ck::Tuple<DDataType>;
using EDataType        = F16;

static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 1;

using AElementOp   = ck::tensor_operation::element_wise::PassThrough;
using BElementOp   = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;

static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;

using DeviceOpInstanceKKNN =
    ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle<NumDimG,
                                                                                  NumDimM,
                                                                                  NumDimN,
                                                                                  NumDimK,
                                                                                  ADataType,
                                                                                  BDataType,
                                                                                  DsDataType,
                                                                                  EDataType,
                                                                                  AccDataType,
                                                                                  CShuffleDataType,
                                                                                  AElementOp,
                                                                                  BElementOp,
                                                                                  CDEElementOp,
                                                                                  GemmSpec,
                                                                                  ABSpec,
                                                                                  ABSpec,
                                                                                  DESpec,
                                                                                  256,
                                                                                  128,
                                                                                  256,
                                                                                  8,
                                                                                  8,
                                                                                  16,
                                                                                  16,
                                                                                  4,
                                                                                  4,
                                                                                  S<4, 64, 1>,
                                                                                  S<1, 0, 2>,
                                                                                  S<1, 0, 2>,
                                                                                  2,
                                                                                  8,
                                                                                  8,
                                                                                  true,
                                                                                  S<4, 64, 1>,
                                                                                  S<1, 0, 2>,
                                                                                  S<1, 0, 2>,
                                                                                  2,
                                                                                  8,
                                                                                  8,
                                                                                  true,
                                                                                  1,
                                                                                  1,
                                                                                  S<1, 32, 1, 8>,
                                                                                  8>;

using DeviceOpInstance = DeviceOpInstanceKKNN;

// hardcoded for NumDimG == 2, NumDimM == 2, NumDimN == 2, NumDimK == 1
template <ck::index_t NumDimG,
          ck::index_t NumDimM,
          ck::index_t NumDimN,
          ck::index_t NumDimK,
          typename ADataType,
          typename BDataType,
          typename EDataType,
          typename AccDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          ck::enable_if_t<NumDimG == 2 && NumDimM == 2 && NumDimN == 2 && NumDimK == 1, bool> =
              false>
struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator
{
    // Argument
    struct Argument : public ck::tensor_operation::device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_gs_ms_ks,
                 const Tensor<BDataType>& b_gs_ns_ks,
                 Tensor<EDataType>& e_gs_ms_ns,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : a_gs_ms_ks_{a_gs_ms_ks},
              b_gs_ns_ks_{b_gs_ns_ks},
              e_gs_ms_ns_{e_gs_ms_ns},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
        }

        const Tensor<ADataType>& a_gs_ms_ks_;
        const Tensor<BDataType>& b_gs_ns_ks_;
        Tensor<EDataType>& e_gs_ms_ns_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };

    // Invoker
    struct Invoker : public ck::tensor_operation::device::BaseInvoker
    {
        using Argument = ReferenceContraction_G2_M2_N2_K1::Argument;

        float Run(const Argument& arg)
        {
            auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) {
                const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];

                AccDataType v_acc = 0;

                for(int k0 = 0; k0 < K0; ++k0)
                {
                    AccDataType v_a;
                    AccDataType v_b;

                    arg.a_element_op_(
                        v_a,
                        ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0)));
                    arg.b_element_op_(
                        v_b,
                        ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0)));

                    v_acc += v_a * v_b;
                }

                AccDataType v_c;

                arg.cde_element_op_(v_c, v_acc);

                arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c;
            };

            make_ParallelTensorFunctor(f_ms_ns,
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
                                       arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
                std::thread::hardware_concurrency());

            return 0;
        }

        float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
    {
        return true;
    }

    static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
                             const Tensor<BDataType>& b_gs_ns_ks,
                             Tensor<EDataType>& e_gs_ms_ns,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{
            a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceContraction_G2_M2_N2_K1"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = true;

    ck::index_t G0 = 1;
    ck::index_t G1 = 2;

    ck::index_t M0 = 4;
    ck::index_t M1 = 128;

    ck::index_t N0 = 16;
    ck::index_t N1 = 256;

    ck::index_t K0 = 2048;

    // A[G0, G1, M0, M1, K0]
    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
    std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
    // B[G0, G1, N0, N1, K0]
    std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
    std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};

    // D[G0, G1, M0, N0, M1, N1]
    std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
    std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
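    // (editorial note) the zero strides above broadcast D across M0 and M1, so D
    // effectively acts as a per-(G0, G1, N0, N1) bias tensor.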
    // E[G0, G1, M0, N0, M1, N1]
    std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
    std::vector<ck::index_t> e_gs_ms_ns_strides{
        G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        exit(0);
    }
    Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
    Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
    Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
    Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
    Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
    std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
    std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
    std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
        break;
    default:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
        break;
    }
    DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf(sizeof(EDataType) *
                           e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
    b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
    d_device_buf.ToDevice(d_gs_ms_ns.mData.data());

    // set zero
    e_device_buf.SetZero();

    auto a_element_op   = AElementOp{};
    auto b_element_op   = BElementOp{};
    auto cde_element_op = CDEElementOp{};

    // device operation
    auto op       = DeviceOpInstance{};
    auto invoker  = op.MakeInvoker();
    auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
                                    b_device_buf.GetDeviceBuffer(),
                                    std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
                                    e_device_buf.GetDeviceBuffer(),
                                    a_gs_ms_ks_lengths,
                                    a_gs_ms_ks_strides,
                                    b_gs_ns_ks_lengths,
                                    b_gs_ns_ks_strides,
                                    std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
                                    std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
                                    e_gs_ms_ns_lengths,
                                    e_gs_ms_ns_strides,
                                    a_element_op,
                                    b_element_op,
                                    cde_element_op);

    if(!op.IsSupportedArgument(argument))
    {
        std::cout << op.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    ck::index_t G =
        ck::accumulate_n<ck::index_t>(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{});

    ck::index_t M = ck::accumulate_n<ck::index_t>(
        e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{});

    ck::index_t N = ck::accumulate_n<ck::index_t>(
        e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{});

    ck::index_t K = ck::accumulate_n<ck::index_t>(
        a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{});
    std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl;
    std::size_t flop      = std::size_t(2) * G * M * N * K;
    std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
                            sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << op.GetTypeString() << std::endl;

    e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());

    if(do_verification)
    {
        Tensor<CShuffleDataType> c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);

        using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1<NumDimG,
                                                                     NumDimM,
                                                                     NumDimN,
                                                                     NumDimK,
                                                                     ADataType,
                                                                     BDataType,
                                                                     CShuffleDataType,
                                                                     AccDataType,
                                                                     AElementOp,
                                                                     BElementOp,
                                                                     PassThrough>;

        auto ref_gemm    = ReferenceOpInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(
            a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});

        ref_invoker.Run(ref_argument);

        for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
        {
            for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1)
            {
                for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0)
                {
                    for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1)
                    {
                        for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0)
                        {
                            for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5];
                                ++n1)
                            {
                                cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
                                               c_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
                                               d_gs_ms_ns(g0, g1, m0, m1, n0, n1));
                            }
                        }
                    }
                }
            }
        }

        return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
    }

    return 0;
}
@@ -16,6 +16,9 @@ if(USE_BITINT_EXTENSION_INT4)
    add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
endif() # USE_BITINT_EXTENSION_INT4

if(GPU_TARGETS MATCHES "gfx1100")
    add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
endif()

add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)

@@ -137,7 +137,7 @@ inline bool parse_cmd_args(int argc,

        const ck::index_t num_dim_spatial = std::stoi(argv[4]);
        conv_param = ck::utils::conv::parse_conv_param(
            num_dim_spatial, threshold_to_catch_partial_args, argv);
            num_dim_spatial, threshold_to_catch_partial_args + 1, argv);
    }
    else
    {

355
example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp
Normal file
@@ -0,0 +1,355 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <algorithm>
#include <array>
#include <iostream>
#include <string>
#include <type_traits>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"

using BF16 = ck::bhalf_t;
using FP16 = ck::half_t;
using FP32 = float;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using I4 = ck::int4_t;
#endif
using I8  = std::int8_t;
using I32 = std::int32_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvSpec =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

template <typename InputLay, typename WeightLay, typename OutputLay>
struct CommonLayoutSetting
{
    using InputLayout  = InputLay;
    using WeightLayout = WeightLay;
    using OutputLayout = OutputLay;
};

template <ck::index_t NDimSpatial>
struct CommonLayoutSettingSelector;

namespace ctl = ck::tensor_layout::convolution;

template <>
struct CommonLayoutSettingSelector<1> final
    : CommonLayoutSetting<ctl::G_NW_C, ctl::G_K_X_C, ctl::G_NW_K>
{
};

template <>
struct CommonLayoutSettingSelector<2> final
    : CommonLayoutSetting<ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>
{
};

template <>
struct CommonLayoutSettingSelector<3> final
    : CommonLayoutSetting<ctl::G_NDHW_C, ctl::G_K_ZYX_C, ctl::G_NDHW_K>
{
};

template <ck::index_t NDimSpatial>
using InputLayout = typename CommonLayoutSettingSelector<NDimSpatial>::InputLayout;

template <ck::index_t NDimSpatial>
using WeightLayout = typename CommonLayoutSettingSelector<NDimSpatial>::WeightLayout;

template <ck::index_t NDimSpatial>
using OutputLayout = typename CommonLayoutSettingSelector<NDimSpatial>::OutputLayout;

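// (editorial sketch, not part of the original header) the selector resolves to the
// conventional grouped-conv layouts, e.g. for 2-D convolution:
//   static_assert(std::is_same_v<InputLayout<2>, ctl::G_NHW_C>);
//   static_assert(std::is_same_v<WeightLayout<2>, ctl::G_K_YX_C>);
//   static_assert(std::is_same_v<OutputLayout<2>, ctl::G_NHW_K>);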
struct ExecutionConfig final
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = true;
};

#define DefaultConvParam                                                       \
    ck::utils::conv::ConvParam                                                 \
    {                                                                          \
        2, 32, 2, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
    }

inline void print_help_msg()
{
    std::cerr << "arg1: verification (0=no, 1=yes)\n"
              << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
              << "arg3: time kernel (0=no, 1=yes)\n"
              << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}

inline bool parse_cmd_args(int argc,
                           char* argv[],
                           ExecutionConfig& config,
                           ck::utils::conv::ConvParam& conv_param)
{
    constexpr int num_execution_config_args =
        3; // arguments for do_verification, init_method, time_kernel
    constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_

    constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args;
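    // (editorial note) the leading 1 accounts for argv[0], the program name.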
    constexpr int threshold_to_catch_all_args =
        threshold_to_catch_partial_args + num_conv_param_leading_args;

    if(argc == 1)
    {
        // use default
    }
    // catch only ExecutionConfig arguments
    else if(argc == threshold_to_catch_partial_args)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);
    }
    // catch both ExecutionConfig & ConvParam arguments
    else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0))
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);

        const ck::index_t num_dim_spatial = std::stoi(argv[4]);
        conv_param = ck::utils::conv::parse_conv_param(
            num_dim_spatial, threshold_to_catch_partial_args + 1, argv);
    }
    else
    {
        print_help_msg();
        return false;
    }

    return true;
}

inline HostTensorDescriptor make_input_descriptor(const ck::utils::conv::ConvParam& conv_param)
{
    switch(conv_param.num_dim_spatial_)
    {
    case 1:
        return HostTensorDescriptor(
            {conv_param.G_, conv_param.N_, conv_param.C_, conv_param.input_spatial_lengths_[0]},
            {
                conv_param.C_,                                                        // g
                conv_param.input_spatial_lengths_[0] * conv_param.G_ * conv_param.C_, // n
                1,                                                                    // c
                conv_param.G_ * conv_param.C_                                         // wi
            });

    case 2:
        return HostTensorDescriptor(
            {conv_param.G_,
             conv_param.N_,
             conv_param.C_,
             conv_param.input_spatial_lengths_[0],
             conv_param.input_spatial_lengths_[1]},
            {
                conv_param.C_, // g
                conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] *
                    conv_param.G_ * conv_param.C_,                                    // n
                1,                                                                    // c
                conv_param.input_spatial_lengths_[1] * conv_param.G_ * conv_param.C_, // hi
                conv_param.G_ * conv_param.C_                                         // wi
            });

    case 3:
        return HostTensorDescriptor(
            {conv_param.G_,
             conv_param.N_,
             conv_param.C_,
             conv_param.input_spatial_lengths_[0],
             conv_param.input_spatial_lengths_[1],
             conv_param.input_spatial_lengths_[2]},
            {
                conv_param.C_, // g
                conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] *
                    conv_param.input_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // n
                1,                                                                        // c
                conv_param.input_spatial_lengths_[1] * conv_param.input_spatial_lengths_[2] *
                    conv_param.G_ * conv_param.C_,                                    // di
                conv_param.input_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // hi
                conv_param.G_ * conv_param.C_                                         // wi
            });
    }

    throw std::runtime_error("unsupported # dim spatial");
}

inline HostTensorDescriptor make_weight_descriptor(const ck::utils::conv::ConvParam& conv_param)
{
    switch(conv_param.num_dim_spatial_)
    {
    case 1:
        return HostTensorDescriptor(
            {conv_param.G_, conv_param.K_, conv_param.C_, conv_param.filter_spatial_lengths_[0]},
            {
                conv_param.K_ * conv_param.filter_spatial_lengths_[0] * conv_param.C_, // g
                conv_param.filter_spatial_lengths_[0] * conv_param.C_,                 // k
                1,                                                                     // c
                conv_param.C_                                                          // x
            });
    case 2:
        return HostTensorDescriptor(
            {conv_param.G_,
             conv_param.K_,
             conv_param.C_,
             conv_param.filter_spatial_lengths_[0],
             conv_param.filter_spatial_lengths_[1]},
            {
                conv_param.K_ * conv_param.filter_spatial_lengths_[0] *
                    conv_param.filter_spatial_lengths_[1] * conv_param.C_, // g
                conv_param.filter_spatial_lengths_[0] * conv_param.filter_spatial_lengths_[1] *
                    conv_param.C_,                                     // k
                1,                                                     // c
                conv_param.filter_spatial_lengths_[1] * conv_param.C_, // y
                conv_param.C_                                          // x
            });
    case 3:
        return HostTensorDescriptor(
            {conv_param.G_,
             conv_param.K_,
             conv_param.C_,
             conv_param.filter_spatial_lengths_[0],
             conv_param.filter_spatial_lengths_[1],
             conv_param.filter_spatial_lengths_[2]},
            {
                conv_param.K_ * conv_param.filter_spatial_lengths_[0] *
                    conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] *
                    conv_param.C_, // g
                conv_param.filter_spatial_lengths_[0] * conv_param.filter_spatial_lengths_[1] *
                    conv_param.filter_spatial_lengths_[2] * conv_param.C_, // k
                1,                                                         // c
                conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] *
                    conv_param.C_,                                     // z
                conv_param.filter_spatial_lengths_[2] * conv_param.C_, // y
                conv_param.C_                                          // x
            });
    }

    throw std::runtime_error("unsupported # dim spatial");
}

inline HostTensorDescriptor make_bias_descriptor(const ck::utils::conv::ConvParam& conv_param)
{
    switch(conv_param.num_dim_spatial_)
    {
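    // (editorial note) the zero strides below broadcast the per-(g, k) bias across
    // the batch and spatial dimensions.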
    case 1:
        return HostTensorDescriptor(
            {conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]},
            {
                conv_param.K_, // g
                0,             // k
                1,             // c
                0              // x
            });
    case 2:
        return HostTensorDescriptor({conv_param.G_,
                                     conv_param.N_,
                                     conv_param.K_,
                                     conv_param.output_spatial_lengths_[0],
                                     conv_param.output_spatial_lengths_[1]},
                                    {
                                        conv_param.K_, // g
                                        0,             // n
                                        1,             // k
                                        0,             // ho
                                        0              // wo
                                    });
    case 3:
        return HostTensorDescriptor({conv_param.G_,
                                     conv_param.N_,
                                     conv_param.K_,
                                     conv_param.output_spatial_lengths_[0],
                                     conv_param.output_spatial_lengths_[1],
                                     conv_param.output_spatial_lengths_[2]},
                                    {
                                        conv_param.K_, // g
                                        0,             // n
                                        1,             // k
                                        0,             // z
                                        0,             // y
                                        0              // x
                                    });
    }

    throw std::runtime_error("unsupported # dim spatial");
}

inline HostTensorDescriptor make_output_descriptor(const ck::utils::conv::ConvParam& conv_param)
{

    switch(conv_param.num_dim_spatial_)
    {
    case 1:
        return HostTensorDescriptor(
            {conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]},
            {
                conv_param.K_,                                                         // g
                conv_param.output_spatial_lengths_[0] * conv_param.G_ * conv_param.K_, // n
                1,                                                                     // k
                conv_param.G_ * conv_param.K_                                          // wo
            });
    case 2:
        return HostTensorDescriptor(
            {conv_param.G_,
             conv_param.N_,
             conv_param.K_,
             conv_param.output_spatial_lengths_[0],
             conv_param.output_spatial_lengths_[1]},
            {
                conv_param.K_, // g
                conv_param.output_spatial_lengths_[0] * conv_param.output_spatial_lengths_[1] *
                    conv_param.G_ * conv_param.K_,                                     // n
                1,                                                                     // k
                conv_param.output_spatial_lengths_[1] * conv_param.G_ * conv_param.K_, // ho
                conv_param.G_ * conv_param.K_                                          // wo
            });

    case 3:
        return HostTensorDescriptor(
            {conv_param.G_,
             conv_param.N_,
             conv_param.K_,
             conv_param.output_spatial_lengths_[0],
             conv_param.output_spatial_lengths_[1],
             conv_param.output_spatial_lengths_[2]},
            {
                conv_param.K_, // g
                conv_param.output_spatial_lengths_[0] * conv_param.output_spatial_lengths_[1] *
                    conv_param.output_spatial_lengths_[2] * conv_param.G_ * conv_param.K_, // n
                1,                                                                         // k
                conv_param.output_spatial_lengths_[1] * conv_param.output_spatial_lengths_[2] *
                    conv_param.G_ * conv_param.K_,                                     // do
                conv_param.output_spatial_lengths_[2] * conv_param.G_ * conv_param.K_, // ho
                conv_param.G_ * conv_param.K_                                          // wo
            });
    }

    throw std::runtime_error("unsupported # dim spatial");
}
@@ -0,0 +1,26 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "common_wmma.hpp"

// kernel data types
using InKernelDataType       = FP16;
using WeiKernelDataType      = FP16;
using AccDataType            = FP32;
using CShuffleDataType       = FP16;
using BiasKernelDataType     = FP16;
using ResidualKernelDataType = FP16;
using OutKernelDataType      = FP16;

// tensor data types
using InUserDataType  = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;

using InElementOp  = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;

#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"

int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
@@ -0,0 +1,286 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

template <typename BiasLay, typename ResidualLay>
struct LayoutSetting
{
    using BiasLayout     = BiasLay;
    using ResidualLayout = ResidualLay;
};

template <ck::index_t NDimSpatial>
struct LayoutSettingSelector;

template <>
struct LayoutSettingSelector<1> final : LayoutSetting<ctl::G_K, ctl::G_NW_K>
{
};

template <>
struct LayoutSettingSelector<2> final : LayoutSetting<ctl::G_K, ctl::G_NHW_K>
{
};

template <>
struct LayoutSettingSelector<3> final : LayoutSetting<ctl::G_K, ctl::G_NDHW_K>
{
};

template <ck::index_t NDimSpatial>
using BiasLayout = typename LayoutSettingSelector<NDimSpatial>::BiasLayout;

template <ck::index_t NDimSpatial>
using ResidualLayout = typename LayoutSettingSelector<NDimSpatial>::ResidualLayout;

template <ck::index_t NDimSpatial>
using DeviceConvFwdInstance =
    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
        NDimSpatial,
        InputLayout<NDimSpatial>,
        WeightLayout<NDimSpatial>,
        ck::Tuple<BiasLayout<NDimSpatial>, ResidualLayout<NDimSpatial>>,
        OutputLayout<NDimSpatial>,
        InKernelDataType,
        WeiKernelDataType,
        ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
        OutKernelDataType,
        AccDataType,
        CShuffleDataType,
        InElementOp,
        WeiElementOp,
        OutElementOp,
        ConvSpec,    // ConvForwardSpecialization
        GemmSpec,    // GemmSpecialization
        256,         // BlockSize
        128,         // MPerBlock
        128,         // NPerBlock
        4,           // K0PerBlock
        8,           // K1
        16,          // MPerWMMA
        16,          // NPerWMMA
        4,           // MRepeat
        2,           // NRepeat
        S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
        S<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,  // ABlockTransferSrcAccessOrder
        2,           // ABlockTransferSrcVectorDim
        8,           // ABlockTransferSrcScalarPerVector
        8,           // ABlockTransferDstScalarPerVector_AK1
        true,        // ABlockLdsExtraM
        S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
        S<1, 0, 2>,  // BBlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,  // BBlockTransferSrcAccessOrder
        2,           // BBlockTransferSrcVectorDim
        8,           // BBlockTransferSrcScalarPerVector
        8,           // BBlockTransferDstScalarPerVector_BK1
        true,        // BBlockLdsExtraN
        1,
        1,
        S<1, 32, 1, 8>,
        8>;

template <ck::index_t NDimSpatial>
using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
                                                                         InUserDataType,
                                                                         WeiUserDataType,
                                                                         CShuffleDataType,
                                                                         InElementOp,
                                                                         WeiElementOp,
                                                                         PassThrough>;

template <ck::index_t NDimSpatial>
bool run_grouped_conv_fwd_bias_relu_add(const ExecutionConfig& config,
                                        const ck::utils::conv::ConvParam& conv_param)
{
    static_assert(1 <= NDimSpatial && NDimSpatial <= 3, "Unsupported NDimSpatial");

    const auto in_g_n_c_wis_desc   = make_input_descriptor(conv_param);
    const auto wei_g_k_c_xs_desc   = make_weight_descriptor(conv_param);
    const auto bias_g_n_k_wos_desc = make_bias_descriptor(conv_param);
    const auto out_g_n_k_wos_desc  = make_output_descriptor(conv_param);

    Tensor<InUserDataType> in(in_g_n_c_wis_desc);
    Tensor<WeiUserDataType> wei(wei_g_k_c_xs_desc);
    Tensor<OutUserDataType> bias(bias_g_n_k_wos_desc);
    Tensor<OutUserDataType> residual(bias_g_n_k_wos_desc);
    Tensor<OutUserDataType> out_host(out_g_n_k_wos_desc);
    Tensor<OutKernelDataType> out_device(out_g_n_k_wos_desc);

    std::cout << "in: " << in.mDesc << std::endl;
    std::cout << "wei: " << wei.mDesc << std::endl;
    std::cout << "bias: " << bias.mDesc << std::endl;
    std::cout << "residual: " << residual.mDesc << std::endl;
    std::cout << "out: " << out_host.mDesc << std::endl;

    switch(config.init_method)
    {
    case 0: break;
    case 1:
        in.GenerateTensorValue(GeneratorTensor_2<InUserDataType>{-5, 5});
        wei.GenerateTensorValue(GeneratorTensor_2<WeiUserDataType>{-5, 5});
        bias.GenerateTensorValue(GeneratorTensor_2<OutUserDataType>{-5, 5});
        break;
    default:
        in.GenerateTensorValue(GeneratorTensor_3<InUserDataType>{0.0, 1.0});
        wei.GenerateTensorValue(GeneratorTensor_3<WeiUserDataType>{-0.5, 0.5});
        bias.GenerateTensorValue(GeneratorTensor_3<OutUserDataType>{-0.5, 0.5});
    }

    DeviceMem in_device_buf(sizeof(InKernelDataType) * in.mDesc.GetElementSpaceSize());
    DeviceMem wei_device_buf(sizeof(WeiKernelDataType) * wei.mDesc.GetElementSpaceSize());
    DeviceMem bias_device_buf(sizeof(OutKernelDataType) * bias.mDesc.GetElementSpaceSize());
    DeviceMem residual_device_buf(sizeof(OutKernelDataType) * residual.mDesc.GetElementSpaceSize());
    DeviceMem out_device_buf(sizeof(OutKernelDataType) * out_device.mDesc.GetElementSpaceSize());

#ifdef BUILD_INT4_EXAMPLE
    const Tensor<InKernelDataType> in_converted(in);
    const Tensor<WeiKernelDataType> wei_converted(wei);
    const Tensor<OutKernelDataType> bias_converted(bias);
    const Tensor<OutKernelDataType> residual_converted(residual);

    in_device_buf.ToDevice(in_converted.mData.data());
    wei_device_buf.ToDevice(wei_converted.mData.data());
    bias_device_buf.ToDevice(bias_converted.mData.data());
    residual_device_buf.ToDevice(residual_converted.mData.data());
#else
    in_device_buf.ToDevice(in.mData.data());
    wei_device_buf.ToDevice(wei.mData.data());
    bias_device_buf.ToDevice(bias.mData.data());
    residual_device_buf.ToDevice(residual.mData.data());
#endif

    std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
    std::array<ck::index_t, NDimSpatial + 3> d0_g_n_k_wos_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> d0_g_n_k_wos_strides{};
    std::array<ck::index_t, NDimSpatial + 3> d1_g_n_k_wos_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> d1_g_n_k_wos_strides{};
    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
    std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
    std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
    std::array<ck::index_t, NDimSpatial> input_left_pads{};
    std::array<ck::index_t, NDimSpatial> input_right_pads{};

    auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
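
    // Copy the runtime HostTensorDescriptor lengths/strides into the fixed-size
    // std::arrays that the device op's MakeArgument expects.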
|
||||
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
|
||||
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
|
||||
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
|
||||
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
|
||||
copy(bias_g_n_k_wos_desc.GetLengths(), d0_g_n_k_wos_lengths);
|
||||
copy(bias_g_n_k_wos_desc.GetStrides(), d0_g_n_k_wos_strides);
|
||||
copy(bias_g_n_k_wos_desc.GetLengths(), d1_g_n_k_wos_lengths);
|
||||
copy(bias_g_n_k_wos_desc.GetStrides(), d1_g_n_k_wos_strides);
|
||||
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
|
||||
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
|
||||
copy(conv_param.conv_filter_strides_, conv_filter_strides);
|
||||
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
|
||||
copy(conv_param.input_left_pads_, input_left_pads);
|
||||
copy(conv_param.input_right_pads_, input_right_pads);
|
||||
|
||||
// do Conv
|
||||
auto conv = DeviceConvFwdInstance<NDimSpatial>{};
|
||||
auto invoker = conv.MakeInvoker();
|
||||
auto argument =
|
||||
conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
|
||||
wei_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, 2>{bias_device_buf.GetDeviceBuffer(),
|
||||
residual_device_buf.GetDeviceBuffer()},
|
||||
out_device_buf.GetDeviceBuffer(),
|
||||
a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
std::array<std::array<ck::index_t, NDimSpatial + 3>, 2>{
|
||||
{d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths}},
|
||||
std::array<std::array<ck::index_t, NDimSpatial + 3>, 2>{
|
||||
{d0_g_n_k_wos_strides, d1_g_n_k_wos_strides}},
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
if(!conv.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem");
|
||||
}
|
||||
|
||||
float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t flop = conv_param.GetFlops();
|
||||
std::size_t num_btype = conv_param.GetByte<InUserDataType, WeiUserDataType, OutUserDataType>();
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / avg_time;
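
    // avg_time is in milliseconds, so flop / 1.E9 / avg_time is FLOP / 1.E12
    // per second, i.e. TFLOPS, and num_btype / 1.E6 / avg_time is bytes /
    // 1.E9 per second, i.e. GB/s.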
    std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << conv.GetTypeString() << std::endl;

    if(config.do_verification)
    {
        Tensor<CShuffleDataType> c_host(out_g_n_k_wos_desc);

        auto ref_conv     = HostConvFwdInstance<NDimSpatial>{};
        auto ref_invoker  = ref_conv.MakeInvoker();
        auto ref_argument = ref_conv.MakeArgument(in,
                                                  wei,
                                                  c_host,
                                                  conv_param.conv_filter_strides_,
                                                  conv_param.conv_filter_dilations_,
                                                  conv_param.input_left_pads_,
                                                  conv_param.input_right_pads_,
                                                  InElementOp{},
                                                  WeiElementOp{},
                                                  PassThrough{});

        ref_invoker.Run(ref_argument);

        // TODO: implement elementwise operation for host
        out_host.ForEach([&](auto&, auto idx) {
            OutElementOp{}(out_host(idx), c_host(idx), bias(idx), residual(idx));
        });
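
        // The device kernel fuses the epilogue, so the host reference applies
        // OutElementOp (the bias add + ReLU + residual add this example is
        // named after) per element on top of the plain convolution result.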

        out_device_buf.FromDevice(out_device.mData.data());

#ifdef BUILD_INT4_EXAMPLE
        const Tensor<OutUserDataType> out_device_converted(out_device);

        return ck::utils::check_err(
            out_device_converted, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
#else
        return ck::utils::check_err(
            out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
#endif
    }

    return true;
}

bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[])
{
    ExecutionConfig config;
    ck::utils::conv::ConvParam conv_param = DefaultConvParam;

    if(!parse_cmd_args(argc, argv, config, conv_param))
    {
        return false;
    }

    switch(conv_param.num_dim_spatial_)
    {
    case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param);
    case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param);
    case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
    }

    return false;
}
@@ -23,11 +23,11 @@
 constexpr int Rank         = 5;
 constexpr int NumReduceDim = 3;
 
-using XDataType     = ck::half_t;
-using GammaDataType = ck::half_t;
-using BetaDataType  = ck::half_t;
-using YDataType     = ck::half_t;
-using AccDataType   = float;
+using XDataType       = ck::half_t;
+using GammaDataType   = ck::half_t;
+using BetaDataType    = ck::half_t;
+using YDataType       = ck::half_t;
+using ComputeDataType = float;
 
 struct YElementOp
 {
@@ -50,7 +50,7 @@ using DeviceInstance =
     ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
                                                           GammaDataType,
                                                           BetaDataType,
-                                                          AccDataType,
+                                                          ComputeDataType,
                                                           YDataType,
                                                           YElementOp,
                                                           Rank,
@@ -157,7 +157,7 @@ int main(int argc, char* argv[])
                                       GammaDataType,
                                       BetaDataType,
                                       YDataType,
-                                      AccDataType,
+                                      ComputeDataType,
                                       YElementOp>;
 
     ReferenceInstance ref;
 
@@ -53,7 +53,6 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi
     DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
     DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
 
-
     a_device_buf.ToDevice(a_m_k.mData.data());
     b_device_buf.ToDevice(b_k_n.mData.data());
     d0_device_buf.ToDevice(d0_m_n.mData.data());
@@ -84,8 +83,8 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi
 
     if(!device_op.IsSupportedArgument(argument))
     {
         std::cout << "wrong! this device_op instance does not support this problem" << std::endl;
         return true;
     }
 
     float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
 
1
example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp)
408
example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp
Normal file
@@ -0,0 +1,408 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using AElementOp    = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp   = ck::tensor_operation::element_wise::PassThrough;
using C0DEElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using Acc0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using B1ElementOp   = ck::tensor_operation::element_wise::PassThrough;
using CElementOp    = ck::tensor_operation::element_wise::PassThrough;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
static constexpr auto TensorSpecA  = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;

using F16 = ck::half_t;
using F32 = float;

using ADataType        = F16;
using B0DataType       = F16;
using B1DataType       = F16;
using AccDataType      = F32;
using CShuffleDataType = F32;
using CDataType        = F16;
using D0DataType       = F16;
using Acc0BiasDataType = ck::Tuple<D0DataType>;
using Acc1BiasDataType = ck::Tuple<>;

static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;

using DeviceOpInstance =
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
        NumDimG,
        NumDimM,
        NumDimN,
        NumDimK,
        NumDimO,
        ADataType,
        B0DataType,
        B1DataType,
        CDataType,
        Acc0BiasDataType,
        Acc1BiasDataType,
        AccDataType,
        CShuffleDataType,
        AElementOp,
        B0ElementOp,
        C0DEElementOp,
        B1ElementOp,
        CElementOp,
        GemmSpec,
        TensorSpecA,
        TensorSpecB0,
        TensorSpecB1,
        TensorSpecC,
        1,
        256,
        128,         // MPerBlock
        128,         // NPerBlock
        32,          // KPerBlock
        64,          // Gemm1NPerBlock
        32,          // Gemm1KPerBlock
        8,           // AK1
        8,           // BK1
        2,           // B1K1
        32,          // MPerXDL
        32,          // NPerXDL
        1,           // MXdlPerWave
        4,           // NXdlPerWave
        2,           // Gemm1NXdlPerWave
        S<4, 64, 1>, // ABlockTransfer
        S<1, 0, 2>,
        S<1, 0, 2>,
        2,
        8,
        8,
        true,
        S<4, 64, 1>, // BBlockTransfer
        S<1, 0, 2>,
        S<1, 0, 2>,
        2,
        8,
        8,
        true,
        S<16, 16, 1>, // B1BlockTransfer
        S<0, 2, 1>,
        S<0, 2, 1>,
        1,
        4,
        2,
        false,
        1,              // CShuffleMXdlPerWavePerShuffle
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
        MaskingSpec>;   // MaskingSpecialization
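
// Rough reading of the tuning parameters above: a 256-thread workgroup works
// on a 128 x 128 Gemm0 tile (stepping 32 deep in K) and a 64-wide Gemm1 tile
// (stepping 32 deep in its K) per iteration, and the S<...> sequences describe
// how the threads are clustered for the block-wise copies. These values are
// one known-good configuration for this example, not the only valid one.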

// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                B0DataType,
                                                                                AccDataType,
                                                                                AccDataType,
                                                                                AElementOp,
                                                                                B0ElementOp,
                                                                                Acc0ElementOp>;

// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
    ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;

// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                B1DataType,
                                                                                CDataType,
                                                                                AccDataType,
                                                                                AElementOp,
                                                                                B1ElementOp,
                                                                                CElementOp>;

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    int G0 = 3;
    int G1 = 2;
    int M  = 1024;
    int N  = 1024;
    int K  = 64;
    int O  = 64;
    float alpha = 1;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 11)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M  = std::stoi(argv[4]);
        N  = std::stoi(argv[5]);
        K  = std::stoi(argv[6]);
        O  = std::stoi(argv[7]);
        G0 = std::stoi(argv[8]);
        G1 = std::stoi(argv[9]);

        alpha = std::stof(argv[10]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 9: M, N, K, O, G0, G1\n");
        printf("arg10: scale (alpha)\n");
        exit(0);
    }

    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
    std::vector<ck::index_t> a_gs_ms_ks_strides{
        M * G1 * K, K, G1 * K, 1}; // A layout [G0, M, G1, K]

    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
    std::vector<ck::index_t> b0_gs_ns_ks_strides{
        N * G1 * K, K, G1 * K, 1}; // B0 layout [G0, N, G1, K]

    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
    std::vector<ck::index_t> b1_gs_os_ns_strides{
        N * G1 * O, O, 1, G1 * O}; // B1 layout [G0, N, G1, O]

    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
    std::vector<ck::index_t> c_gs_ms_os_strides{
        M * G1 * O, O, G1 * O, 1}; // C layout [G0, M, G1, O]

    // D layout [G0, M, G1, N]
    std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
    std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};
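
    // Worked example with the defaults G0 = 3, G1 = 2, M = N = 1024: D is
    // stored as [G0, M, G1, N], so a step in G0 jumps M * G1 * N elements, a
    // step in M jumps G1 * N, a step in G1 jumps N, and N is contiguous.
    // That is exactly {M * G1 * N, N, G1 * N, 1} in [G0, G1, M, N] order.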

    Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
    Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
    Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
    Tensor<D0DataType> d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides);
    Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
    Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);

    std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
    std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
    std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
    std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
        break;
    case 2:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-1, 1});
        break;
    case 3:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
        break;
    default:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
    }

    DeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K);
    DeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K);
    DeviceMem d0_device_buf(sizeof(D0DataType) * G0 * G1 * M * N);
    DeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N);
    DeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O);

    a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
    b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data());
    b1_device_buf.ToDevice(b1_gs_os_ns.mData.data());
    d0_device_buf.ToDevice(d0_gs_ms_ns.mData.data());

    auto device_op = DeviceOpInstance{};
    auto invoker   = device_op.MakeInvoker();

    auto a_element_op    = AElementOp{};
    auto b0_element_op   = B0ElementOp{};
    auto c0de_element_op = C0DEElementOp{alpha};
    auto acc0_element_op = Acc0ElementOp{};
    auto b1_element_op   = B1ElementOp{};
    auto c_element_op    = CElementOp{};
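
    // C0DEElementOp is ScaleAdd, so the Gemm0 accumulator is combined with
    // the D0 bias roughly as acc = alpha * acc + d0 before the softmax; the
    // PassThrough ops leave their operands unchanged.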

    auto argument = device_op.MakeArgument(
        static_cast<const ADataType*>(a_device_buf.GetDeviceBuffer()),
        static_cast<const B0DataType*>(b0_device_buf.GetDeviceBuffer()),
        static_cast<const B1DataType*>(b1_device_buf.GetDeviceBuffer()),
        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
        std::array<void*, 1>{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases
        {},                                                    // p_acc1_biases
        a_gs_ms_ks_lengths,
        a_gs_ms_ks_strides,
        b0_gs_ns_ks_lengths,
        b0_gs_ns_ks_strides,
        b1_gs_os_ns_lengths,
        b1_gs_os_ns_strides,
        c_gs_ms_os_lengths,
        c_gs_ms_os_strides,
        std::array<std::vector<ck::index_t>, 1>{
            d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
        std::array<std::vector<ck::index_t>, 1>{
            d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
        {},                       // acc1_biases_gs_ms_os_lengths
        {},                       // acc1_biases_gs_ms_os_strides
        a_element_op,
        b0_element_op,
        c0de_element_op,
        b1_element_op,
        c_element_op);

    if(!device_op.IsSupportedArgument(argument))
    {
        throw std::runtime_error("wrong! this device_op instance does not support this problem");
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    ck::index_t BatchCount = G0 * G1;
    std::size_t flop       = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
    std::size_t num_btype =
        (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
         sizeof(CDataType) * M * O + sizeof(D0DataType) * M * N) *
        BatchCount;
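
    // Each batch runs two GEMMs, Gemm0 (M x N x K) and Gemm1 (M x O x N), and
    // a GEMM costs 2 * rows * cols * inner FLOPs (one multiply and one add
    // per inner-product term), hence 2 * M * N * K + 2 * M * N * O per batch.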

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << std::endl;

    if(do_verification)
    {
        c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

        Tensor<ADataType> a_g_m_k({BatchCount, M, K});
        Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
        Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
        Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N});        // scratch object after gemm0
        Tensor<ADataType> a1_g_m_n({BatchCount, M, N});            // scratch object after softmax
        Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
        Tensor<D0DataType> d0_g_m_n({BatchCount, M, N});

        // permute
        a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
            a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
        });
        b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
            b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
        });
        b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
            b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
        });
        d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
            d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
        });

        // gemm 0
        auto ref_gemm0          = ReferenceGemm0Instance{};
        auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
            a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);

        ref_gemm0_invoker.Run(ref_gemm0_argument);

        acc0_g_m_n.ForEach([&](auto&, auto idx) {
            c0de_element_op(acc0_g_m_n(idx), acc0_g_m_n(idx), d0_g_m_n(idx));
        });
        // masking
        const auto mask = DeviceOpInstance::C0MatrixMask(N);
        acc0_g_m_n.ForEach([&](auto& self, auto idx) {
            if(mask.IsMaskedElement(idx[1], idx[2]))
                self(idx) = -ck::NumericLimits<float>::Infinity();
        });
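
        // Masked positions are set to -inf so that the softmax assigns them
        // zero probability (exp(-inf) = 0). With MaskDisabled, as configured
        // above, no element is actually masked and this loop is a no-op.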

        // softmax
        auto ref_softmax          = ReferenceSoftmaxInstance{};
        auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
        auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});

        ref_softmax_invoker.Run(ref_softmax_argument);

        // gemm1
        auto ref_gemm1          = ReferenceGemm1Instance{};
        auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
            a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);

        ref_gemm1_invoker.Run(ref_gemm1_argument);

        // permute
        c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
            const size_t& g0 = idx[0];
            const size_t& g1 = idx[1];

            const size_t g = g0 * G1 + g1;

            self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
        });

        // relative and absolute error tolerances both default to 0.001
        double rtol = 1e-3;
        double atol = 1e-3;

        return ck::utils::check_err(c_gs_ms_os_device_result.mData,
                                    c_gs_ms_os_host_result.mData,
                                    "Error: Incorrect results!",
                                    rtol,
                                    atol)
                   ? 0
                   : 1;
    }

    return 0;
}
@@ -172,13 +172,6 @@
 // tuning parameter
 #define CK_WORKAROUND_SWDEV_325164 0
 
-// workaround: a BF16 attention kernel for gfx908 is likely affected by a compiler issue
-#ifdef __gfx908__
-#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 1
-#else // __gfx90a__, ...
-#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 0
-#endif // __gfx908__
-
 // flag to enable (1) or disable (0) the debugging output in some kernels
 #define DEBUG_LOG 0
 
@@ -20,6 +20,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
 #if CK_TIME_KERNEL
     if(stream_config.time_kernel_)
     {
+#if DEBUG_LOG
         printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
                __func__,
                grid_dim.x,
@@ -29,15 +30,15 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
                block_dim.y,
                block_dim.z);
 
-        const int nrepeat = 10;
-
-        printf("Warm up 1 time\n");
-
+#endif
         // warm up
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
 
+        const int nrepeat = 10;
+#if DEBUG_LOG
         printf("Start running %d times...\n", nrepeat);
 
+#endif
         hipEvent_t start, stop;
 
         hip_check_error(hipEventCreate(&start));
@@ -26,9 +26,9 @@ template <index_t NumDimG,
           typename Acc1BiasDataType,
           typename AElementwiseOperation,
           typename B0ElementwiseOperation,
-          typename Acc0ElementwiseOperation,
+          typename C0DEElementwiseOperation,
           typename B1ElementwiseOperation,
-          typename CElementwiseOperation,
+          typename C1DEElementwiseOperation,
           MaskingSpecialization MaskingSpec>
 struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
 {
@@ -58,9 +58,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
             acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
         AElementwiseOperation a_element_op,
         B0ElementwiseOperation b0_element_op,
-        Acc0ElementwiseOperation acc0_element_op,
+        C0DEElementwiseOperation c0de_element_op,
         B1ElementwiseOperation b1_element_op,
-        CElementwiseOperation c_element_op) = 0;
+        C1DEElementwiseOperation c1de_element_op) = 0;
 
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
 
@@ -14,9 +14,9 @@ namespace device {
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
-          typename AccDataType,
+          typename ComputeDataType,
           typename YDataType,
-          typename AccElementwiseOperation,
+          typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim>
 struct DeviceNormalization : public BaseOperator
@@ -35,7 +35,7 @@ struct DeviceNormalization : public BaseOperator
         void* p_y,
         void* p_savedMean,
         void* p_savedInvVar,
-        AccElementwiseOperation acc_elementwise_op) = 0;
+        YElementwiseOperation y_elementwise_op) = 0;
 
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
@@ -43,17 +43,17 @@ struct DeviceNormalization : public BaseOperator
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
-          typename AccDataType,
+          typename ComputeDataType,
           typename YDataType,
-          typename AccElementwiseOperation,
+          typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim>
 using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
                                                                    GammaDataType,
                                                                    BetaDataType,
-                                                                   AccDataType,
+                                                                   ComputeDataType,
                                                                    YDataType,
-                                                                   AccElementwiseOperation,
+                                                                   YElementwiseOperation,
                                                                    Rank,
                                                                    NumReduceDim>>;
 
@@ -0,0 +1,991 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// Tensor Contraction:
//   input : A
//   input : B
//   input : D0, D1, ...
//   output : E
//   C = a_op(A) * b_op(B)
//   E = cde_op(C, D0, D1, ...)
// Assume:
//   A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
//   B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
//   D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
//   E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]

// NOTE: a TensorSpecialization::Packed tensor is "packed" in the sense that the inner
// dimensions in each dimension group (e.g. [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are
// contiguous and ordered. It is not packed in the sense that the tensor
// [G0, G1, ..., M0, M1, ..., N0, N1, ...] can be permuted while still being a contiguous,
// unpadded tensor. In other words, it merely degenerates into
// TensorSpecialization::Default with NumDimG/M/N/K = 1.
//
// Detail: a packed tensor satisfies
//   stride_0 = 1
//   stride_i = stride_{i - 1} * extent_{i - 1}
// So the tensor
//   [G0, G1, G2, M, N]
// transposed into the tensor
//   [G0, G2, G1, M, N]
// with strides
//   [G2 * G1 * M * N, G1 * M * N, M * N, N, 1]
// is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and
// ignores some strides from the input tensor extents, so finer dimension information is
// lost. Merging dimensions is essentially a degenerate case of
// TensorSpecialization::Default with NumDimG/M/N/K = 1.
//
// Might need to expose dimension order to the interface to fully support
// TensorSpecialization::Packed in a traditional sense of "packed" tensor
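//
// Worked example of the recurrence above (hypothetical sizes): a packed tensor
// with extents [G0 = 2, G1 = 3, M = 4, N = 5] has strides
// [G1 * M * N, M * N, N, 1] = [60, 20, 5, 1]; counting from the innermost
// dimension outward, each stride is the previous stride times the previous
// extent, which is exactly the stated recurrence.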
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          TensorSpecialization ASpec,
          TensorSpecialization BSpec,
          TensorSpecialization DESpec,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t K1,
          ck::index_t MPerWMMA,
          ck::index_t NPerWMMA,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool ABlockLdsAddExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BBlockLdsAddExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          ck::index_t NumPrefetch         = 1,
          ck::LoopScheduler LoopSched     = make_default_loop_scheduler(),
          ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
    : public DeviceBatchedContractionMultipleD<NumDimG,
                                               NumDimM,
                                               NumDimN,
                                               NumDimK,
                                               ADataType,
                                               BDataType,
                                               DsDataType,
                                               EDataType,
                                               AElementwiseOperation,
                                               BElementwiseOperation,
                                               CDEElementwiseOperation>
{
    using DeviceOp = DeviceBatchedContractionMultipleD_Wmma_CShuffle;
    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    // K1 = Max Vector Access Pixels
    static constexpr auto K1Number = Number<K1>{};

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock * K1};
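
    // When GemmSpec requests padding, the padder rounds the raw M/N/K extents
    // up to multiples of MPerBlock / NPerBlock / (K0PerBlock * K1), so that
    // arbitrary problem sizes can be tiled by the fixed block shape.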

    // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
    static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                        const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK &&
               a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto a_ms_ks_lengths = to_tuple(
            a_gs_ms_ks_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimK>{});
        const auto a_ms_ks_strides = to_tuple(
            a_gs_ms_ks_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimK>{});

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};

        // dimension Ids for K0, K1, ...
        constexpr auto kDimIds =
            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);

        // lengths for K0, K1, ...
        const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);

        if constexpr(ASpec == TensorSpecialization::Packed)
        {
            auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
            auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
            const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
                make_tuple(M, K),
                make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
                           a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
            return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
        }
        else
        {
            // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
            const auto a_grid_desc_ms_ks =
                make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);

            // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
            const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
                a_grid_desc_ms_ks,
                make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
                make_tuple(mDimIds, kDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
        }
    }

    // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
    static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
                                        const std::vector<index_t>& b_gs_ns_ks_strides_vec)
    {
        assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
               b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto b_ns_ks_lengths = to_tuple(
            b_gs_ns_ks_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimN + NumDimK>{});
        const auto b_ns_ks_strides = to_tuple(
            b_gs_ns_ks_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimN + NumDimK>{});

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};

        // dimension Ids for K0, K1, ...
        constexpr auto kDimIds =
            typename arithmetic_sequence_gen<NumDimN, NumDimN + NumDimK, 1>::type{};

        // lengths for K0, K1, ...
        const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);

        if constexpr(BSpec == TensorSpecialization::Packed)
        {
            auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
            auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
            const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
                make_tuple(N, K),
                make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
                           b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
            return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        }
        else
        {
            // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
            const auto b_grid_desc_ns_ks =
                make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);

            // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
            const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
                b_grid_desc_ns_ks,
                make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
                make_tuple(nDimIds, kDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        }
    }

    // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
    static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_gs_ms_ns_lengths_vec,
                                        const std::vector<index_t>& e_gs_ms_ns_strides_vec)
    {
        assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
               e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto e_ms_ns_lengths = to_tuple(
            e_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto e_ms_ns_strides = to_tuple(
            e_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds =
            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);

        if constexpr(DESpec == TensorSpecialization::Packed)
        {
            auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
            auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
            const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor(
                make_tuple(M, N),
                make_tuple(e_ms_ns_strides[Number<NumDimM - 1>{}],
                           e_ms_ns_strides[Number<NumDimM + NumDimN - 1>{}]));
            return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
        }
        else
        {
            // naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
            const auto e_grid_desc_ms_ns =
                make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);

            // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
            const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
                e_grid_desc_ms_ns,
                make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
                make_tuple(mDimIds, nDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
        }
    }

    // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
    static auto MakeEGridDescriptor_G_M_N(const std::vector<index_t>& e_gs_ms_ns_lengths_vec,
                                          const std::vector<index_t>& e_gs_ms_ns_strides_vec)
    {
        assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
               e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto e_gs_ms_ns_lengths =
            to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto e_gs_ms_ns_strides =
            to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // dimension Ids for G0, G1, ...
        constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds =
            typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
                                                                  NumDimG + NumDimM + NumDimN,
                                                                  1>::type{};

        // lengths for G0, G1, ...
        const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds);

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds);

        if constexpr(DESpec == TensorSpecialization::Packed)
        {
            auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
            auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
            auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
            const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
                make_tuple(G, M, N),
                make_tuple(e_gs_ms_ns_strides[Number<NumDimG - 1>{}],
                           e_gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
                           e_gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
            // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
            return e_grid_desc_g_mraw_nraw;
        }
        else
        {
            // naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
            const auto e_grid_desc_gs_ms_ns =
                make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);

            // transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
            // N2 * ...]
            const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor(
                e_grid_desc_gs_ms_ns,
                make_tuple(make_merge_transform(gLengths),
                           make_merge_transform(mLengths),
                           make_merge_transform(nLengths)),
                make_tuple(gDimIds, mDimIds, nDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
            return e_grid_desc_g_mraw_nraw;
        }
    }

    static auto MakeDsGridDescriptor_M_N(
        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths_vec,
        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides_vec)
    {
        return generate_tuple(
            [&](auto i) {
                return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i],
                                                         ds_gs_ms_ns_strides_vec[i]);
            },
            Number<NumDTensor>{});
    }

    static auto MakeDsGridDescriptor_G_M_N(
        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths_vec,
        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides_vec)
    {
        return generate_tuple(
            [&](auto i) {
                return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i],
                                                           ds_gs_ms_ns_strides_vec[i]);
            },
            Number<NumDTensor>{});
    }

    // Gridwise descriptor, mapping to the whole given problem.
    using AGridDesc_M_K  = decltype(MakeAGridDescriptor_M_K({}, {}));
    using BGridDesc_N_K  = decltype(MakeBGridDescriptor_N_K({}, {}));
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
    using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N({}, {}));

    using DsGridDesc_G_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_G_M_N({}, {}))>;
    using EGridDesc_G_M_N  = decltype(MakeEGridDescriptor_G_M_N({}, {}));

    struct ComputePtrOffsetOfStridedBatch
    {
        ComputePtrOffsetOfStridedBatch(index_t batch_stride_A,
                                       index_t batch_stride_B,
                                       DsGridDesc_G_M_N ds_grid_desc_g_m_n,
                                       EGridDesc_G_M_N e_grid_desc_g_m_n)
            : batch_stride_A_(batch_stride_A),
              batch_stride_B_(batch_stride_B),
              ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n),
              e_grid_desc_g_m_n_(e_grid_desc_g_m_n)
        {
        }

        __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
        {
            return static_cast<long_index_t>(g_idx) * batch_stride_A_;
        }

        __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
        {
            return static_cast<long_index_t>(g_idx) * batch_stride_B_;
        }

        __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
        {
            std::array<long_index_t, NumDTensor> ds_offset;

            static_for<0, NumDTensor, 1>{}([&](auto i) {
                ds_offset[i] = static_cast<long_index_t>(g_idx) *
                               ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0));
            });

            return ds_offset;
        }

        __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
        {
            return static_cast<long_index_t>(g_idx) *
                   e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0));
        }
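
        // CalculateOffset(make_multi_index(1, 0, 0)) evaluates the linear
        // offset of element (g = 1, m = 0, n = 0), i.e. the batch stride along
        // the merged G dimension, so multiplying it by g_idx yields the
        // pointer offset of batch g_idx.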

        private:
        index_t batch_stride_A_;
        index_t batch_stride_B_;
        DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
        EGridDesc_G_M_N e_grid_desc_g_m_n_;
    };

    // A desc for source in blockwise copy
    template <typename AGridDesc_M_K>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK0 = K / K1;

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }
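
    // Small example: for M = 128, K = 64, K1 = 8 this splits K into
    // (AK0 = 8, K1 = 8) and reorders the view to [K0, M, K1] = [8, 128, 8],
    // the layout the block-wise copy consumes; K is assumed to be divisible
    // by K1.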
|
||||
|
||||
// B desc for source in blockwise copy
|
||||
template <typename BGridDesc_N_K>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k)
|
||||
{
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
const auto BK0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
using AGridDesc_K0_M_K1 = decltype(DeviceOp::MakeAGridDescriptor_K0_M_K1(AGridDesc_M_K{}));
|
||||
using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1(BGridDesc_N_K{}));
|
||||
|
||||
// GridwiseOp
|
||||
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
|
||||
// DataType Family
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
// InMemory Data Descriptor
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
DsGridDesc_M_N,
|
||||
EGridDesc_M_N,
|
||||
// ElementwiseOp Family
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerWMMA,
|
||||
NPerWMMA,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
// ThreadCluster Family
|
||||
BlockSize,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
NumPrefetch,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const void* p_a_grid,
|
||||
const void* p_b_grid,
|
||||
std::array<const void*, NumDTensor> p_ds_grid,
|
||||
void* p_e_grid,
|
||||
const std::vector<index_t>& a_gs_ms_ks_lengths,
|
||||
const std::vector<index_t>& b_gs_ns_ks_lengths,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
|
||||
const std::vector<index_t>& e_gs_ms_ns_lengths,
|
||||
const std::vector<index_t>& a_gs_ms_ks_strides,
|
||||
const std::vector<index_t>& b_gs_ns_ks_strides,
|
||||
const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
|
||||
const std::vector<index_t>& e_gs_ms_ns_strides,
|
||||
index_t M01,
|
||||
index_t N01,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
|
||||
a_grid_desc_m_k_{},
|
||||
b_grid_desc_n_k_{},
|
||||
ds_grid_desc_m_n_{},
|
||||
e_grid_desc_m_n_{},
|
||||
ds_grid_desc_g_m_n_{
|
||||
DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)},
|
||||
e_grid_desc_g_m_n_{
|
||||
DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
a_mz_stride_{},
|
||||
a_kz_stride_{},
|
||||
b_nz_stride_{},
|
||||
b_kz_stride_{},
|
||||
ds_nz_stride_{},
|
||||
e_nz_stride_{},
|
||||
a_batch_stride_{a_gs_ms_ks_strides[NumDimG - 1]},
|
||||
b_batch_stride_{b_gs_ns_ks_strides[NumDimG - 1]},
|
||||
compute_ptr_offset_of_batch_{
|
||||
a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_}
|
||||
{
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
// D pointer
|
||||
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
|
||||
});
|
||||
|
||||
a_grid_desc_m_k_ =
|
||||
DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
|
||||
b_grid_desc_n_k_ =
|
||||
DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
|
||||
|
||||
ds_grid_desc_m_n_ =
|
||||
DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
|
||||
|
||||
e_grid_desc_m_n_ =
|
||||
DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
|
||||
|
||||
a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_);
|
||||
b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_);
|
||||
|
||||
block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n_);
|
||||
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
|
||||
|
||||
// for sanity check of vector memory access
|
||||
a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1];
|
||||
a_kz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1];
|
||||
b_nz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN - 1];
|
||||
b_kz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1];
|
||||
|
||||
for(index_t i = 0; i < NumDTensor; ++i)
|
||||
{
|
||||
ds_nz_stride_[i] = ds_gs_ms_ns_strides[i][NumDimG + NumDimM + NumDimN - 1];
|
||||
}
|
||||
|
||||
e_nz_stride_ = e_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1];
|
||||
}
|
||||
|
||||
// Pointers
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
typename GridwiseOp::DsGridPointer p_ds_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// Tensor Descriptors
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
|
||||
EGridDesc_G_M_N e_grid_desc_g_m_n_;
|
||||
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
|
||||
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
// Block to Tile mapping
|
||||
typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
|
||||
// Idle
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
|
||||
// ElementwiseOp
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
// Strides for the last M/N/K dimensions of A/B/Ds/E
|
||||
// for sanity check of vector load/store
|
||||
index_t a_mz_stride_;
|
||||
index_t a_kz_stride_;
|
||||
index_t b_nz_stride_;
|
||||
index_t b_kz_stride_;
|
||||
std::array<index_t, NumDTensor> ds_nz_stride_;
|
||||
index_t e_mz_stride_;
|
||||
index_t e_nz_stride_;
|
||||
|
||||
index_t a_batch_stride_;
|
||||
index_t b_batch_stride_;
|
||||
|
||||
// Batch Offset
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0);
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

                const auto kernel = kernel_contraction_multiple_d_wmma_cshuffle<
                    GridwiseOp,
                    ADataType,
                    BDataType,
                    typename GridwiseOp::DsGridPointer,
                    EDataType,
                    DeviceOp::AGridDesc_K0_M_K1,
                    DeviceOp::BGridDesc_K0_N_K1,
                    typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    ComputePtrOffsetOfStridedBatch,
                    typename GridwiseOp::DefaultBlock2CTileMap,
                    has_main_loop>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_ds_grid_,
                                              arg.p_e_grid_,
                                              G,
                                              arg.a_grid_desc_k0_m_k1_,
                                              arg.b_grid_desc_k0_n_k1_,
                                              arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.cde_element_op_,
                                              arg.compute_ptr_offset_of_batch_,
                                              arg.block_2_ctile_map_);
            };

            if(GridwiseOp::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
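        // WMMA instructions are an RDNA3 feature; this op is currently gated to gfx1100
        // and to float/int32 accumulation.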
        if(ck::get_device_name() == "gfx1100")
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                      arg.b_grid_desc_k0_n_k1_,
                                      arg.ds_grid_desc_m_n_,
                                      arg.e_grid_desc_m_n_,
                                      arg.block_2_ctile_map_))
        {
            return false;
        }

        // check vector access
        static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
                          (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
                      "wrong!");

        // vector memory access of A: could be on M or AK1 dimension
        if constexpr(ABlockTransferSrcVectorDim == 1)
        {
            if(!(arg.a_mz_stride_ == 1 &&
                 arg.a_grid_desc_k0_m_k1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
            {
                return false;
            }
        }
        else
        {
            if(!(arg.a_kz_stride_ == 1 &&
                 arg.a_grid_desc_k0_m_k1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
            {
                return false;
            }
        }

        // vector memory access of B: could be on N or BK1 dimension
        if constexpr(BBlockTransferSrcVectorDim == 1)
        {
            if(!(arg.b_nz_stride_ == 1 &&
                 arg.b_grid_desc_k0_n_k1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
            {
                return false;
            }
        }
        else
        {
            if(!(arg.b_kz_stride_ == 1 &&
                 arg.b_grid_desc_k0_n_k1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
            {
                return false;
            }
        }

        // vector memory access of Ds: always on NPerBlock dimension
        bool valid_d_access = true;

        static_for<0, NumDTensor, 1>{}([&](auto i) {
            if(!(arg.ds_nz_stride_[i] == 1 &&
                 arg.ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetLength(I3) %
                         CDEShuffleBlockTransferScalarPerVector_NPerBlock ==
                     0))
            {
                valid_d_access = false;
            }
        });

        if(valid_d_access == false)
        {
            return false;
        }

        // vector memory access of E: always on NPerBlock dimension
        if(!((arg.e_nz_stride_ == 1 &&
              arg.e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I3) %
                      CDEShuffleBlockTransferScalarPerVector_NPerBlock ==
                  0) ||
             CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1))
        {
            return false;
        }

        return true;
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto
    MakeArgument(const void* p_a,
                 const void* p_b,
                 std::array<const void*, NumDTensor> p_ds,
                 void* p_e,
                 const std::vector<index_t>& a_gs_ms_ks_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,
                 const std::vector<index_t>& b_gs_ns_ks_strides,
                 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
                 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
                 const std::vector<index_t>& e_gs_ms_ns_lengths,
                 const std::vector<index_t>& e_gs_ms_ns_strides,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_ds,
                        p_e,
                        a_gs_ms_ks_lengths,
                        b_gs_ns_ks_lengths,
                        ds_gs_ms_ns_lengths,
                        e_gs_ms_ns_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_strides,
                        ds_gs_ms_ns_strides,
                        e_gs_ms_ns_strides,
                        1,
                        1,
                        a_element_op,
                        b_element_op,
                        cde_element_op};
    }

    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
                        const std::vector<index_t>& a_gs_ms_ks_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
                        const std::vector<index_t>& b_gs_ns_ks_strides,
                        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
                        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
                        const std::vector<index_t>& e_gs_ms_ns_lengths,
                        const std::vector<index_t>& e_gs_ms_ns_strides,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op) override
    {
        return std::make_unique<Argument>(p_a,
                                          p_b,
                                          p_ds,
                                          p_e,
                                          a_gs_ms_ks_lengths,
                                          b_gs_ns_ks_lengths,
                                          ds_gs_ms_ns_lengths,
                                          e_gs_ms_ns_lengths,
                                          a_gs_ms_ks_strides,
                                          b_gs_ns_ks_strides,
                                          ds_gs_ms_ns_strides,
                                          e_gs_ms_ns_strides,
                                          1,
                                          1,
                                          a_element_op,
                                          b_element_op,
                                          cde_element_op);
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
                                                                       {PipelineVersion::v2, "v2"}};

        // clang-format off
        str << "DeviceBatchedContractionMultipleD_Wmma_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << K1 << ", "
            << MPerWMMA << ", "
            << NPerWMMA << ", "
            << MRepeat << ", "
            << NRepeat
            << ">"
            << " NumPrefetch: "
            << NumPrefetch << ", "
            << "LoopScheduler: "
            << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: "
            << PipelineVersionToString[PipelineVer];
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -25,15 +25,17 @@ namespace device {
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename D0sPointer,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename AccElementwiseOperation,
          typename C0DEElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          typename C1DEElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename B1GridDesc_BK0_N_BK1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
          typename Block2CTileMap,
          typename ComputeBasePtrOfStridedBatch,
          typename C0MatrixMask,
@@ -47,16 +49,19 @@ __global__ void
            const FloatAB* __restrict__ p_b_grid,
            const FloatAB* __restrict__ p_b1_grid,
            FloatC* __restrict__ p_c_grid,
            D0sPointer p_d0s_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const AccElementwiseOperation acc_element_op,
            const C0DEElementwiseOperation c0de_element_op,
            const B1ElementwiseOperation b1_element_op,
            const CElementwiseOperation c_element_op,
            const C1DEElementwiseOperation c1de_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c1_grid_desc_mblock_mperblock_nblock_nperblock,
            const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
                d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
            const Block2CTileMap block_2_ctile_map,
            const index_t batch_count,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
@@ -78,20 +83,28 @@ __global__ void
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));

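    // Advance each optional D0 tensor pointer by its own per-batch base offset before
    // running the gridwise GEMM.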
    static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
        const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
        p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
    });

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                  p_b_grid + b_batch_offset,
                                                  p_b1_grid + b1_batch_offset,
                                                  p_c_grid + c_batch_offset,
                                                  p_d0s_grid,
                                                  p_shared,
                                                  a_element_op,
                                                  b_element_op,
                                                  acc_element_op,
                                                  c0de_element_op,
                                                  b1_element_op,
                                                  c_element_op,
                                                  c1de_element_op,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  b1_grid_desc_bk0_n_bk1,
                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  c1_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                                                  block_2_ctile_map,
                                                  c0_matrix_mask);
#else
@@ -99,15 +112,17 @@ __global__ void
    ignore = p_b_grid;
    ignore = p_b1_grid;
    ignore = p_c_grid;
    ignore = p_d0s_grid;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = acc_element_op;
    ignore = c0de_element_op;
    ignore = b1_element_op;
    ignore = c_element_op;
    ignore = c1de_element_op;
    ignore = a_grid_desc_ak0_m_ak1;
    ignore = b_grid_desc_bk0_n_bk1;
    ignore = b1_grid_desc_bk0_n_bk1;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
    ignore = block_2_ctile_map;
    ignore = batch_count;
    ignore = compute_base_ptr_of_batch;
@@ -127,15 +142,15 @@ template <index_t NumDimG,
          typename BDataType,
          typename B1DataType,
          typename CDataType,
          typename Acc0BiasDataType,
          typename Acc1BiasDataType,
          typename D0sDataType,
          typename D1sDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename AccElementwiseOperation,
          typename C0DEElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          typename C1DEElementwiseOperation,
          GemmSpecialization GemmSpec,
          TensorSpecialization ASpec,
          TensorSpecialization BSpec,
@@ -193,23 +208,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        BDataType,
        B1DataType,
        CDataType,
        Acc0BiasDataType,
        Acc1BiasDataType,
        D0sDataType,
        D1sDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        AccElementwiseOperation,
        C0DEElementwiseOperation,
        B1ElementwiseOperation,
        CElementwiseOperation,
        C1DEElementwiseOperation,
        MaskingSpec>
{
    static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
                  "Number of dimensions must be greater than 0");

    static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
    static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
    static constexpr index_t NumD0Tensor = D0sDataType::Size();
    static constexpr index_t NumD1Tensor = D1sDataType::Size();

    // TODO ANT: implement bias combination
    static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
    static_assert(NumD1Tensor == 0, "Gemm1 Bias addition is unimplemented");

#if 0
    // TODO ANT: use alias
@@ -262,14 +277,40 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                   Number<B1K1>{});
    }

    static auto MakeD0sGridDescriptor_M_N(
        const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
        const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
    {
        return generate_tuple(
            [&](auto i) {
                return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[i],
                                                          acc0_biases_gs_ms_ns_strides[i]);
            },
            Number<NumD0Tensor>{});
    }

    static auto MakeD0sGridDescriptor_G_M_N(
        const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
        const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides)
    {
        return generate_tuple(
            [&](auto i) {
                return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths[i],
                                                            acc0_biases_gs_ms_ns_strides[i]);
            },
            Number<NumD0Tensor>{});
    }

    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
    using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using C1GridDesc_M_N       = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
    using BGridDesc_G_N_K      = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
    using B1GridDesc_G_N_K     = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
    using CGridDesc_G_M_N      = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
    using C1GridDesc_G_M_N     = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
    using D0sGridDesc_M_N      = decltype(MakeD0sGridDescriptor_M_N({}, {}));
    using D0sGridDesc_G_M_N    = decltype(MakeD0sGridDescriptor_G_M_N({}, {}));

    constexpr static auto make_MaskOutPredicate()
    {
@@ -289,11 +330,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
                                     const BGridDesc_G_N_K& b_grid_desc_g_n_k,
                                     const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n)
                                     const C1GridDesc_G_M_N& c1_grid_desc_g_m_n,
                                     const D0sGridDesc_G_M_N& d0s_grid_desc_g_m_n)
            : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
              b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
              b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
              c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
              c1_grid_desc_g_m_n_(c1_grid_desc_g_m_n),
              d0s_grid_desc_g_m_n_(d0s_grid_desc_g_m_n)
        {
        }

@@ -314,32 +357,42 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle

        __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
        {
            return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
            return c1_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        template <index_t I>
        __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
                                                                Number<I> d0_idx) const
        {
            return d0s_grid_desc_g_m_n_[d0_idx].CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        private:
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        BGridDesc_G_N_K b_grid_desc_g_n_k_;
        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
        C1GridDesc_G_M_N c1_grid_desc_g_m_n_;
        D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;
    };

    // GridwiseGemm
    using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
    using GridwiseGemm = GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle<
        ADataType, // TODO: distinguish A/B datatype
        GemmAccDataType,
        CShuffleDataType,
        CDataType,
        D0sDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        AccElementwiseOperation,
        C0DEElementwiseOperation,
        B1ElementwiseOperation,
        CElementwiseOperation,
        C1DEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
        B1GridDesc_BK0_N_BK1,
        CGridDesc_M_N,
        C1GridDesc_M_N,
        D0sGridDesc_M_N,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
@@ -396,8 +449,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 const BDataType* p_b_grid,
                 const B1DataType* p_b1_grid,
                 CDataType* p_c_grid,
                 const std::array<void*, NumAcc0Bias> p_acc0_biases,
                 const std::array<void*, NumAcc1Bias> p_acc1_biases,
                 const std::array<void*, NumD0Tensor> p_acc0_biases,
                 const std::array<void*, NumD1Tensor> p_acc1_biases,
                 const std::vector<index_t>& a_gs_ms_ks_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -406,44 +459,48 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                 const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                 const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
                 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
                 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
                 const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                 const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
                 const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides,
                 const std::array<std::vector<ck::index_t>, NumD1Tensor>&
                     acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
                 const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                 const std::array<std::vector<ck::index_t>, NumD1Tensor>&
                     acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 AccElementwiseOperation acc_element_op,
                 C0DEElementwiseOperation c0de_element_op,
                 B1ElementwiseOperation b1_element_op,
                 CElementwiseOperation c_element_op)
                 C1DEElementwiseOperation c1de_element_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_b1_grid_{p_b1_grid},
              p_c_grid_{p_c_grid},
              p_d0s_grid_{},
              a_grid_desc_ak0_m_ak1_{
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_bk0_n_bk1_{
                  DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
              c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
                                                                  c_gs_ms_gemm1ns_strides)},
              c1_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
                                                                   c_gs_ms_gemm1ns_strides)},
              a_grid_desc_g_m_k_{
                  Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_g_n_k_{
                  Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
              c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                      c_gs_ms_gemm1ns_strides)},
              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
              c1_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                       c_gs_ms_gemm1ns_strides)},
              d0s_grid_desc_g_m_n_{DeviceOp::MakeD0sGridDescriptor_G_M_N(
                  acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)},
              c1_grid_desc_mblock_mperblock_nblock_nperblock_{},
              d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c1_grid_desc_m_n_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              acc_element_op_{acc_element_op},
              c0de_element_op_{c0de_element_op},
              b1_element_op_{b1_element_op},
              c_element_op_{c_element_op},
              c1de_element_op_{c1de_element_op},
              c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
              raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
                                            b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
@@ -457,27 +514,39 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                      b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
              c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
                                    c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
              batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
              compute_base_ptr_of_batch_{
                  a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_}
              batch_count_{c1_grid_desc_g_m_n_.GetLength(I0)},
              compute_base_ptr_of_batch_{a_grid_desc_g_m_k_,
                                         b_grid_desc_g_n_k_,
                                         b1_grid_desc_g_n_k_,
                                         c1_grid_desc_g_m_n_,
                                         d0s_grid_desc_g_m_n_}
        {
            // TODO ANT: implement bias addition
            ignore = p_acc0_biases;
            ignore = p_acc1_biases;
            ignore = acc0_biases_gs_ms_ns_lengths;
            ignore = acc0_biases_gs_ms_ns_strides;
            ignore = acc1_biases_gs_ms_gemm1ns_lengths;
            ignore = acc1_biases_gs_ms_gemm1ns_strides;

            static_for<0, NumD0Tensor, 1>{}([&](auto i) {
                using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
                // D0 pointer
                p_d0s_grid_(i) = static_cast<const D0DataType*>(p_acc0_biases[i]);
            });

            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                           b_grid_desc_bk0_n_bk1_,
                                           b1_grid_desc_bk0_n_bk1_,
                                           c_grid_desc_m_n_,
                                           c1_grid_desc_m_n_,
                                           block_2_ctile_map_))
            {
                c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c_grid_desc_m_n_);
                c1_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c1_grid_desc_m_n_);

                D0sGridDesc_M_N d0s_grid_desc_m_n{DeviceOp::MakeD0sGridDescriptor_M_N(
                    acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)};
                d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
                    GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
                        d0s_grid_desc_m_n);
            }
        }

@@ -492,9 +561,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
                      << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
                      << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
            std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
                      << c_grid_desc_g_m_n_.GetLength(I1) << ", "
                      << c_grid_desc_g_m_n_.GetLength(I2) << '\n';
            std::cout << "c1_grid_desc_g_m_n_: " << c1_grid_desc_g_m_n_.GetLength(I0) << ", "
                      << c1_grid_desc_g_m_n_.GetLength(I1) << ", "
                      << c1_grid_desc_g_m_n_.GetLength(I2) << '\n';
        }

        // pointers
@@ -502,18 +571,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        const BDataType* p_b_grid_;
        const B1DataType* p_b1_grid_;
        CDataType* p_c_grid_;
        typename GridwiseGemm::D0sGridPointer p_d0s_grid_;

        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        C1GridDesc_M_N c1_grid_desc_m_n_;
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        BGridDesc_G_N_K b_grid_desc_g_n_k_;
        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock_;
        C1GridDesc_G_M_N c1_grid_desc_g_m_n_;
        D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_;

        typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c1_grid_desc_mblock_mperblock_nblock_nperblock_;
        typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
            d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;

        // block-to-c-tile map
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
@@ -521,9 +595,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        AccElementwiseOperation acc_element_op_;
        C0DEElementwiseOperation c0de_element_op_;
        B1ElementwiseOperation b1_element_op_;
        CElementwiseOperation c_element_op_;
        C1DEElementwiseOperation c1de_element_op_;

        // check C0 masking and padding
        C0MatrixMask c0_matrix_mask_;
@@ -552,7 +626,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
            }

            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
                arg.block_2_ctile_map_.CalculateGridSize(arg.c1_grid_desc_m_n_) * arg.batch_count_;

            // Gemm0_K
            const auto K =
@@ -565,15 +639,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                GridwiseGemm,
                ADataType, // TODO: distinguish A/B datatype
                CDataType,
                typename GridwiseGemm::D0sGridPointer,
                AElementwiseOperation,
                BElementwiseOperation,
                AccElementwiseOperation,
                C0DEElementwiseOperation,
                B1ElementwiseOperation,
                CElementwiseOperation,
                C1DEElementwiseOperation,
                DeviceOp::AGridDesc_AK0_M_AK1,
                DeviceOp::BGridDesc_BK0_N_BK1,
                DeviceOp::B1GridDesc_BK0_N_BK1,
                typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
                typename GridwiseGemm::DefaultBlock2CTileMap,
                ComputeBasePtrOfStridedBatch,
                C0MatrixMask,
@@ -588,15 +664,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                    arg.p_b_grid_,
                    arg.p_b1_grid_,
                    arg.p_c_grid_,
                    arg.p_d0s_grid_,
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.acc_element_op_,
                    arg.c0de_element_op_,
                    arg.b1_element_op_,
                    arg.c_element_op_,
                    arg.c1de_element_op_,
                    arg.a_grid_desc_ak0_m_ak1_,
                    arg.b_grid_desc_bk0_n_bk1_,
                    arg.b1_grid_desc_bk0_n_bk1_,
                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                    arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
                    arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                    arg.block_2_ctile_map_,
                    arg.batch_count_,
                    arg.compute_base_ptr_of_batch_,
@@ -646,9 +724,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        // TODO ANT: Check if tensor specialization & strides mismatch

        // Check if C permute dimension matches GEMM + GEMM shape
        const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
        const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
        const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
        const index_t c_g = arg.c1_grid_desc_g_m_n_.GetLength(I0); // unpadded
        const index_t c_m = arg.c1_grid_desc_m_n_.GetLength(I0);
        const index_t c_gemm1n = arg.c1_grid_desc_m_n_.GetLength(I1);
        const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
        const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);

@@ -698,7 +776,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                           arg.b_grid_desc_bk0_n_bk1_,
                                           arg.b1_grid_desc_bk0_n_bk1_,
                                           arg.c_grid_desc_m_n_,
                                           arg.c1_grid_desc_m_n_,
                                           arg.block_2_ctile_map_);
    }

@@ -713,8 +791,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                             const BDataType* p_b,
                             const B1DataType* p_b1,
                             CDataType* p_c,
                             const std::array<void*, NumAcc0Bias> p_acc0_biases,
                             const std::array<void*, NumAcc1Bias> p_acc1_biases,
                             const std::array<void*, NumD0Tensor> p_acc0_biases,
                             const std::array<void*, NumD1Tensor> p_acc1_biases,
                             const std::vector<index_t>& a_gs_ms_ks_lengths,
                             const std::vector<index_t>& a_gs_ms_ks_strides,
                             const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -723,17 +801,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                             const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                             const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                             const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
                             const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
                             const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
                             const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                             const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
                             const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
                             const std::array<std::vector<ck::index_t>, NumD1Tensor>
                                 acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
                             const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                             const std::array<std::vector<ck::index_t>, NumD1Tensor>
                                 acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             AccElementwiseOperation acc_element_op,
                             C0DEElementwiseOperation c0de_element_op,
                             B1ElementwiseOperation b1_element_op,
                             CElementwiseOperation c_element_op)
                             C1DEElementwiseOperation c1de_element_op)
    {
        return Argument{p_a,
                        p_b,
@@ -755,9 +833,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                        acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
                        a_element_op,
                        b_element_op,
                        acc_element_op,
                        c0de_element_op,
                        b1_element_op,
                        c_element_op};
                        c1de_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }
@@ -769,8 +847,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                        const void* p_b,
                        const void* p_b1,
                        void* p_c,
                        const std::array<void*, NumAcc0Bias> p_acc0_biases,
                        const std::array<void*, NumAcc1Bias> p_acc1_biases,
                        const std::array<void*, NumD0Tensor> p_acc0_biases,
                        const std::array<void*, NumD1Tensor> p_acc1_biases,
                        const std::vector<index_t>& a_gs_ms_ks_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -779,17 +857,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                        const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
                        const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
                        const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
                        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
                        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
                        const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                        const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
                        const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
                        const std::array<std::vector<ck::index_t>, NumD1Tensor>
                            acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
                        const std::array<std::vector<ck::index_t>, NumAcc1Bias>
                        const std::array<std::vector<ck::index_t>, NumD1Tensor>
                            acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        AccElementwiseOperation acc_element_op,
                        C0DEElementwiseOperation c0de_element_op,
                        B1ElementwiseOperation b1_element_op,
                        CElementwiseOperation c_element_op) override
                        C1DEElementwiseOperation c1de_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
@@ -811,9 +889,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                          acc1_biases_gs_ms_gemm1ns_strides,
                                          a_element_op,
                                          b_element_op,
                                          acc_element_op,
                                          c0de_element_op,
                                          b1_element_op,
                                          c_element_op);
                                          c1de_element_op);
    }

    // polymorphic

@@ -533,6 +533,11 @@ struct DeviceElementwiseNormalizationImpl
            return (false);
        }

        if(p_arg_->x_lds_size_ >= 65536)
        {
            return (false);
        }

        return true;
    };


@@ -670,6 +670,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
            {
                throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
            }
            if(arg.p_workspace_e_grid_ == nullptr || arg.p_workspace_mean_ == nullptr ||
               arg.p_workspace_var_ == nullptr || arg.p_workspace_count_ == nullptr)
                throw std::runtime_error("wrong! WorkSpace pointer has not been set");

            index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.gemm_e_grid_desc_m_n_);

@@ -941,7 +944,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
            }
        }

        return true;
        return GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_,
                                                  arg.b_grid_desc_n_k_,
                                                  arg.ds_grid_desc_m_n_,
                                                  arg.gemm_e_grid_desc_m_n_,
                                                  arg.block_2_etile_map_);
    }

    // polymorphic
@@ -1057,7 +1064,12 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
            << GemmKPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << getGemmSpecializationString(GemmSpec)
            << getGemmSpecializationString(GemmSpec) << ", "
            << PostShuffleThreadClusterSize_M_N::At(I0) << ", "
            << PostShuffleThreadClusterSize_M_N::At(I1) << ", "
            << LayernormThreadClusterSize_M_N::At(I0) << ", "
            << LayernormThreadClusterSize_M_N::At(I1) << ", "
            << LayernormThreadSliceSize_M
            << ">"
            << " LoopScheduler: "
            << LoopSchedToString[LoopSched] << ", "

@@ -0,0 +1,654 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t K1,
          ck::index_t MPerWMMA,
          ck::index_t NPerWMMA,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool ABlockLdsAddExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BBlockLdsAddExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          ck::index_t NumPrefetch = 1,
          ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
          ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
                                                                      BLayout,
                                                                      DsLayout,
                                                                      ELayout,
                                                                      ADataType,
                                                                      BDataType,
                                                                      DsDataType,
                                                                      EDataType,
                                                                      AElementwiseOperation,
                                                                      BElementwiseOperation,
                                                                      CDEElementwiseOperation>
{
    using DeviceOp = DeviceGemmMultipleD_Wmma_CShuffle;
    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    // K1 = Max Vector Access Pixels
    static constexpr auto K1Number = Number<K1>{};

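    // The flat K dimension is unmerged into (K0, K1) with K = K0 * K1; K1 is the contiguous
    // innermost chunk read per vector memory access.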
    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
    {
        assert(K % K1 == 0);

        const index_t K0 = K / K1;

        const auto a_grid_desc_m_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
#ifdef ENABLE_COLMAJOR
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
#endif
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_right_pad_transform(M, PadM)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
    }

    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
    {
        assert(K % K1 == 0);

        const index_t K0 = K / K1;

        const auto b_grid_desc_k_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

            return transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_right_pad_transform(N, PadN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            return transform_tensor_descriptor(
                b_grid_desc_k_n,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
    }

    template <typename ELayout_>
    static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
    {
        const auto e_grid_desc_m_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout_>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout_>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

            return transform_tensor_descriptor(
                e_grid_desc_m_n,
                make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {

            return transform_tensor_descriptor(
                e_grid_desc_m_n,
                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }

    static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& Ms,
                                         const std::array<index_t, NumDTensor>& Ns,
                                         const std::array<index_t, NumDTensor>& DsStride)
    {
        return generate_tuple(
            [&](auto i) {
                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

                return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(Ms[i], Ns[i], DsStride[i]);
            },
            Number<NumDTensor>{});
    }

    // Gridwise descriptors, mapping to the whole given problem.
    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
    using DsGridDesc_M_N    = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
    using EGridDesc_M_N     = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));

    // GridwiseOp
    using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
        // DataType Family
        ADataType,
        BDataType,
        AccDataType,
        CShuffleDataType,
        DsDataType,
        EDataType,
        // InMemory Data Descriptor
        AGridDesc_K0_M_K1,
        BGridDesc_K0_N_K1,
        DsGridDesc_M_N,
        EGridDesc_M_N,
        // ElementwiseOp Family
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        // Tiling Family
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerWMMA,
        NPerWMMA,
        K1,
        MRepeat,
        NRepeat,
        // ThreadCluster Family
        BlockSize,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsAddExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEShuffleBlockTransferScalarPerVector_NPerBlock,
        NumPrefetch,
        LoopSched,
        PipelineVer>;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const void* p_a_grid,
                 const void* p_b_grid,
                 std::array<const void*, NumDTensor> p_ds_grid,
                 void* p_e_grid,
                 index_t M,
                 index_t N,
                 index_t K,
                 index_t StrideA,
                 index_t StrideB,
                 std::array<index_t, NumDTensor> StrideDs,
                 index_t StrideE,
                 index_t M01,
                 index_t N01,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
              p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
              p_ds_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e_grid)},
              a_grid_desc_k0_m_k1_{},
              b_grid_desc_k0_n_k1_{},
              ds_grid_desc_m_n_{},
              e_grid_desc_m_n_{},
              ds_grid_desc_mblock_mperblock_nblock_nperblock{},
              e_grid_desc_mblock_mperblock_nblock_nperblock{},
              block_2_ctile_map_{},
              M01_{M01},
              N01_{N01},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
            a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
            b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DLayout   = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                // D pointer
                p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);

                // D desc
                ds_grid_desc_m_n_(i) =
                    DeviceOp::MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[i]);
            });
            e_grid_desc_m_n_ = DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);

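            // The block-to-C-tile map translates a flat workgroup index into the
            // (MBlock, NBlock) coordinate of the E tile that the workgroup computes.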
            block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);

            if(GridwiseOp::CheckValidity(a_grid_desc_k0_m_k1_,
                                         b_grid_desc_k0_n_k1_,
                                         ds_grid_desc_m_n_,
                                         e_grid_desc_m_n_,
                                         block_2_ctile_map_))
            {
                ds_grid_desc_mblock_mperblock_nblock_nperblock =
                    GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        ds_grid_desc_m_n_);

                e_grid_desc_mblock_mperblock_nblock_nperblock =
                    GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        e_grid_desc_m_n_);
            }
        }

        // Pointers
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        typename GridwiseOp::DsGridPointer p_ds_grid_;
        EDataType* p_e_grid_;

        // Tensor Descriptors
        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;
        typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            ds_grid_desc_mblock_mperblock_nblock_nperblock;
        typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            e_grid_desc_mblock_mperblock_nblock_nperblock;

        // Block to Tile mapping
        typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_;

        // Idle
        index_t M01_;
        index_t N01_;

        // ElementwiseOp
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
#if 0
            {
                std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
                          << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
                          << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
                          << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
                          << arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
            }
#endif

            if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                          arg.b_grid_desc_k0_n_k1_,
                                          arg.ds_grid_desc_m_n_,
                                          arg.e_grid_desc_m_n_,
                                          arg.block_2_ctile_map_))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting");
            }

            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);

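            // Recover the total K extent from the K0/K1 split of the A descriptor: K = K0 * K1.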
            const auto K =
                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

            float ave_time = 0;

            if(GridwiseOp::CalculateHasMainKBlockLoop(K))
            {
                const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
                    GridwiseOp,
                    ADataType,
                    BDataType,
                    typename GridwiseOp::DsGridPointer,
                    EDataType,
                    remove_reference_t<typename DeviceOp::AGridDesc_K0_M_K1>,
                    remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
                    remove_reference_t<
                        typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                    remove_reference_t<
                        typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
                    true>; // last template argument: HasMainKBlockLoop
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
                                               arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                               arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                               arg.a_element_op_,
                                               arg.b_element_op_,
                                               arg.cde_element_op_,
                                               arg.block_2_ctile_map_);
            }
            else
            {
                const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
                    GridwiseOp,
                    ADataType,
                    BDataType,
                    typename GridwiseOp::DsGridPointer,
                    EDataType,
                    remove_reference_t<typename DeviceOp::AGridDesc_K0_M_K1>,
                    remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
                    remove_reference_t<
                        typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                    remove_reference_t<
                        typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
                    false>;

                ave_time =
                    launch_and_time_kernel(stream_config,
                                           kernel,
                                           dim3(grid_size),
                                           dim3(BlockSize),
                                           0,
                                           arg.p_a_grid_,
                                           arg.p_b_grid_,
                                           arg.p_ds_grid_,
                                           arg.p_e_grid_,
                                           arg.a_grid_desc_k0_m_k1_,
                                           arg.b_grid_desc_k0_n_k1_,
                                           arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                           arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                           arg.a_element_op_,
                                           arg.b_element_op_,
                                           arg.cde_element_op_,
                                           arg.block_2_ctile_map_);
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(ck::get_device_name() == "gfx1100")
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        return GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                         arg.b_grid_desc_k0_n_k1_,
                                         arg.ds_grid_desc_m_n_,
                                         arg.e_grid_desc_m_n_,
                                         arg.block_2_ctile_map_);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const void* p_a,
                             const void* p_b,
                             std::array<const void*, NumDTensor> p_ds,
                             void* p_e,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             std::array<ck::index_t, NumDTensor> StrideDs,
                             index_t StrideE,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_ds,
                        p_e,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideDs,
                        StrideE,
                        1,
                        1,
                        a_element_op,
                        b_element_op,
                        cde_element_op};
    }

    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
                        index_t M,
                        index_t N,
                        index_t K,
                        index_t StrideA,
                        index_t StrideB,
                        std::array<ck::index_t, NumDTensor> StrideDs,
                        index_t StrideE,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op) override
    {
        return std::make_unique<Argument>(p_a,
                                          p_b,
                                          p_ds,
                                          p_e,
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideDs,
                                          StrideE,
                                          1,
                                          1,
                                          a_element_op,
                                          b_element_op,
                                          cde_element_op);
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
                                                                       {PipelineVersion::v2, "v2"}};

        // clang-format off
        str << "DeviceGemmMultipleD_Wmma_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << K1 << ", "
            << MPerWMMA << ", "
            << NPerWMMA << ", "
            << MRepeat << ", "
            << NRepeat
            << ">"
            << " NumPrefetch: "
            << NumPrefetch << ", "
            << "LoopScheduler: "
            << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: "
            << PipelineVersionToString[PipelineVer];
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
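For orientation, a minimal client-side sketch of how a device op like the one above is typically driven (illustrative only; `DeviceOp` stands for a fully instantiated DeviceGemmMultipleD_Wmma_CShuffle specialization, and the pointers, sizes, strides, and element-op objects are placeholders):

    // assumes device buffers p_a, p_b, p_ds, p_e were already allocated and filled
    auto op       = DeviceOp{};
    auto argument = DeviceOp::MakeArgument(p_a, p_b, p_ds, p_e,
                                           M, N, K,
                                           StrideA, StrideB, StrideDs, StrideE,
                                           a_op, b_op, cde_op);
    if(op.IsSupportedArgument(argument))
    {
        auto invoker   = DeviceOp::MakeInvoker();
        float ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); // time the kernel
    }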
@@ -0,0 +1,850 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

namespace {

template <index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch
{
    ComputePtrOffsetOfStridedBatch() = default;

    ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
                                   index_t BatchStrideB,
                                   Array<ck::index_t, NumDTensor> BatchStrideDs,
                                   index_t BatchStrideE)
        : BatchStrideA_(BatchStrideA),
          BatchStrideB_(BatchStrideB),
          BatchStrideDs_(BatchStrideDs),
          BatchStrideE_(BatchStrideE)
    {
    }

    __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideA_);
    }

    __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideB_);
    }

    __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
    {
        Array<long_index_t, NumDTensor> ds_offset;
        static_for<0, NumDTensor, 1>{}(
            [&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
        return ds_offset;
    }

    __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideE_);
    }

    index_t BatchStrideA_;
    index_t BatchStrideB_;
    Array<ck::index_t, NumDTensor> BatchStrideDs_;
    index_t BatchStrideE_;
};

} // namespace

//
// @brief Device Convolution operation.
//
// Supports:
// @li         Forward convolution with up to 3 spatial dimensions
// @li         Input tensor in GNWC data format
// @li         Weight tensor in GKXC data format
// @li         Output tensor in GNWK data format
//
// 1D:
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
// 2D:
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
// Assume:
//   AK1 == BK1
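// Illustrative mapping (not part of the original source): after the conv-to-GEMM
// transform referenced above, a 2D forward conv with output [N, Ho, Wo, K] becomes a
// GEMM with
//   GemmM = N * Ho * Wo,  GemmN = K,  GemmK = Y * X * C
// e.g. N=2, Ho=Wo=28, K=64, Y=X=3, C=32  ->  M=1568, N=64, K=288 (per group).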
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          ConvolutionForwardSpecialization ConvForwardSpecialization,
          GemmSpecialization GemmSpec,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t K1,
          ck::index_t MPerWMMA,
          ck::index_t NPerWMMA,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          index_t NumGemmKPrefetchStage = 1,
          LoopScheduler LoopSched       = make_default_loop_scheduler(),
          ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
    : public DeviceGroupedConvFwdMultipleD<NDimSpatial,
                                           ALayout,
                                           BLayout,
                                           DsLayout,
                                           ELayout,
                                           ADataType,
                                           BDataType,
                                           DsDataType,
                                           EDataType,
                                           AElementwiseOperation,
                                           BElementwiseOperation,
                                           CDEElementwiseOperation>
{
    using DeviceOp = DeviceGroupedConvFwdMultipleD_Wmma_CShuffle;

    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr index_t KPerBlock = K0PerBlock * K1;

    static constexpr auto conv_to_gemm_transformer =
        TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

    template <typename ALay>
    static auto
    MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                            const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                            const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                            const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
                            const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
                            const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                            const std::array<index_t, NDimSpatial>& conv_filter_strides,
                            const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                            const std::array<index_t, NDimSpatial>& input_left_pads,
                            const std::array<index_t, NDimSpatial>& input_right_pads)
    {
        const auto in_gemmmraw_gemmkraw_desc =
            conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
                                                                        a_g_n_c_wis_strides,
                                                                        b_g_k_c_xs_lengths,
                                                                        b_g_k_c_xs_strides,
                                                                        e_g_n_k_wos_lengths,
                                                                        e_g_n_k_wos_strides,
                                                                        conv_filter_strides,
                                                                        conv_filter_dilations,
                                                                        input_left_pads,
                                                                        input_right_pads);

        const auto in_gemmm_gemmk_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);

        return in_gemmm_gemmk_desc;
    }

    template <typename BLay>
    static auto
    MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                            const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
    {
        const auto wei_gemmnraw_gemmkraw_desc =
            conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
                                                                        b_g_k_c_xs_strides);

        const auto wei_gemmn_gemmk_desc =
            matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);

        return wei_gemmn_gemmk_desc;
    }

    template <typename ELay>
    static auto
    MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
                            const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
    {
        const auto out_gemmmraw_gemmnraw_desc =
            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
                                                                        e_g_n_k_wos_strides);

        const auto out_gemmm_gemmn_desc =
            matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);

        return out_gemmm_gemmn_desc;
    }

    static auto MakeDsGridDescriptor_M_N(
        const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
        const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
    {
        return generate_tuple(
            [&](auto i) {
                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

                return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_g_n_k_wos_lengths[i],
                                                                  ds_g_n_k_wos_strides[i]);
            },
            Number<NumDTensor>{});
    }

    // desc for problem definition
    using AGridDesc_M_K = remove_cvref_t<decltype(
        MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
    using BGridDesc_N_K  = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
    using EGridDesc_M_N  = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;

    // A desc for source in blockwise copy
    template <typename AGridDesc_M_K>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK1 = K1;
        const auto AK0 = K / AK1;

        return transform_tensor_descriptor(a_grid_desc_m_k,
                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                                                      make_pass_through_transform(M)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    // B desc for source in blockwise copy
    template <typename BGridDesc_N_K>
    __host__ __device__ static constexpr auto
    MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
    {
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

        const auto BK1 = K1;
        const auto BK0 = K / BK1;

        return transform_tensor_descriptor(b_grid_desc_n_k,
                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                      make_pass_through_transform(N)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    using AGridDesc_AK0_M_AK1 = decltype(DeviceOp::MakeAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}));
    using BGridDesc_BK0_N_BK1 = decltype(DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}));

    // GridwiseOp
    using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
        // DataType Family
        ADataType,
        BDataType,
        AccDataType,
        CShuffleDataType,
        DsDataType,
        EDataType,
        // InMemory Data Descriptor
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
        DsGridDesc_M_N,
        EGridDesc_M_N,
        // ElementwiseOp Family
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        // Tiling Family
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerWMMA,
        NPerWMMA,
        K1,
        MRepeat,
        NRepeat,
        // ThreadCluster Family
        BlockSize,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEShuffleBlockTransferScalarPerVector_NPerBlock,
        NumGemmKPrefetchStage,
        LoopSched,
        PipelineVer>;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const void* p_a,
                 const void* p_b,
                 const std::array<const void*, NumDTensor>& p_ds,
                 void* p_e,
                 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
                 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
                     ds_g_n_k_wos_lengths,
                 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
                     ds_g_n_k_wos_strides,
                 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
                 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                 const std::array<index_t, NDimSpatial>& conv_filter_strides,
                 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                 const std::array<index_t, NDimSpatial>& input_left_pads,
                 const std::array<index_t, NDimSpatial>& input_right_pads,
                 index_t M01,
                 index_t N01,
                 const AElementwiseOperation& a_element_op,
                 const BElementwiseOperation& b_element_op,
                 const CDEElementwiseOperation& cde_element_op)
            : p_a_grid_{static_cast<const ADataType*>(p_a)},
              p_b_grid_{static_cast<const BDataType*>(p_b)},
              p_ds_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e)},
              num_group_{a_g_n_c_wis_lengths[0]},
              a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
                                                                          a_g_n_c_wis_strides,
                                                                          b_g_k_c_xs_lengths,
                                                                          b_g_k_c_xs_strides,
                                                                          e_g_n_k_wos_lengths,
                                                                          e_g_n_k_wos_strides,
                                                                          conv_filter_strides,
                                                                          conv_filter_dilations,
                                                                          input_left_pads,
                                                                          input_right_pads)},
              b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
                                                                          b_g_k_c_xs_strides)},
              ds_grid_desc_m_n_{},
              e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
                                                                          e_g_n_k_wos_strides)},
              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
              b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
              ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
              e_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)},
              compute_ptr_offset_of_batch_{},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op},
              a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
              a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
              b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
              b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
              ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
              ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
              e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
              e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
              conv_filter_strides_{conv_filter_strides},
              conv_filter_dilations_{conv_filter_dilations},
              input_left_pads_{input_left_pads},
              input_right_pads_{input_right_pads}
        {
            // A/B/E Batch Stride
            compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
            compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
            compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];

            // populate pointer, batch stride, desc for Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                // using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                // D pointer
                p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);

                // D batch stride
                compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
            });

            // D desc
            ds_grid_desc_m_n_ =
                DeviceOp::MakeDsGridDescriptor_M_N(ds_g_n_k_wos_lengths, ds_g_n_k_wos_strides);

            // populate desc for Ds/E
            e_grid_desc_mblock_mperblock_nblock_nperblock_ =
                GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
            ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
                GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                    ds_grid_desc_m_n_);
        }

        void Print() const
        {
            std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
            std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
            static_for<0, NumDTensor, 1>{}(
                [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
            std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
        }

        //  private:
        // pointers
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        typename GridwiseOp::DsGridPointer p_ds_grid_;
        EDataType* p_e_grid_;

        // tensor descriptors for problem definition
        index_t num_group_;
        AGridDesc_M_K a_grid_desc_m_k_;
        BGridDesc_N_K b_grid_desc_n_k_;
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;

        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            ds_grid_desc_mblock_mperblock_nblock_nperblock_;
        typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            e_grid_desc_mblock_mperblock_nblock_nperblock_;

        // block-to-e-tile map
        typename GridwiseOp::DefaultBlock2CTileMap block_2_etile_map_;

        // for computing batch offset
        ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;

        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;

        // for checking IsSupportedArgument()
        std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
        std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
        std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
        std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
        std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
        std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
        std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
        std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
        std::array<index_t, NDimSpatial> conv_filter_strides_;
        std::array<index_t, NDimSpatial> conv_filter_dilations_;
        std::array<index_t, NDimSpatial> input_left_pads_;
        std::array<index_t, NDimSpatial> input_right_pads_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                arg.Print();
            }

            const index_t grid_size =
                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;

            const auto K =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

                const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle<
                    GridwiseOp,
                    ADataType,
                    BDataType,
                    typename GridwiseOp::DsGridPointer,
                    EDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    DeviceOp::BGridDesc_BK0_N_BK1,
                    typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
                    ComputePtrOffsetOfStridedBatch<NumDTensor>,
                    has_main_loop>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_ds_grid_,
                                              arg.p_e_grid_,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.cde_element_op_,
                                              arg.a_g_n_c_wis_lengths_[0], // Group count
                                              arg.a_grid_desc_ak0_m_ak1_,
                                              arg.b_grid_desc_bk0_n_bk1_,
                                              arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.block_2_etile_map_,
                                              arg.compute_ptr_offset_of_batch_);
            };

            if(GridwiseOp::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static bool IsSupportedArgument(const Argument& arg)
    {
        namespace ctc = tensor_layout::convolution;

        // check device
        if(get_device_name() == "gfx1100")
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check ConvolutionForwardSpecialization
        if constexpr(ConvForwardSpecialization ==
                     ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
        {
            // check if it's 1x1, stride=1 conv
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t X          = arg.b_g_k_c_xs_lengths_[i + 2];
                const index_t ConvStride = arg.conv_filter_strides_[i];
                const index_t LeftPad    = arg.input_left_pads_[i];
                const index_t RightPad   = arg.input_right_pads_[i];

                if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
                {
                    return false;
                }
            }
        }
        else if constexpr(ConvForwardSpecialization ==
                          ConvolutionForwardSpecialization::Filter1x1Pad0)
        {
            // check if it's 1x1 conv
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t X        = arg.b_g_k_c_xs_lengths_[i + 2];
                const index_t LeftPad  = arg.input_left_pads_[i];
                const index_t RightPad = arg.input_right_pads_[i];

                if(!(X == 1 && LeftPad == 0 && RightPad == 0))
                {
                    return false;
                }
            }
        }

        // check vector access of A
        // FIXME: layout
        if constexpr(is_same_v<ALayout, ctc::G_NW_C> || is_same_v<ALayout, ctc::G_NHW_C> ||
                     is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
                     is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
                     is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
                     is_same_v<ALayout, ctc::NDHWGC>)
        {
            const index_t C = arg.a_g_n_c_wis_lengths_[2];

            if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector access of B
        // FIXME: layout
        if constexpr(is_same_v<BLayout, ctc::G_K_X_C> || is_same_v<BLayout, ctc::G_K_YX_C> ||
                     is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
                     is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
                     is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
                     is_same_v<BLayout, ctc::KZYXGC>)

        {
            const index_t C = arg.b_g_k_c_xs_lengths_[2];

            if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector access of Ds
        bool valid = true;

        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

            // FIXME: layout
            if constexpr(is_same_v<DLayout, ctc::G_NW_K> || is_same_v<DLayout, ctc::G_NHW_K> ||
                         is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
                         is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> ||
                         is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
                         is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::GK> ||
                         is_same_v<DLayout, ctc::G_K>)
            {
                const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];

                if(!(K % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
                {
                    valid = false;
                }
            }
            else
            {
                valid = false;
            }
        });

        if(!valid)
        {
            return false;
        }

        // check vector access of E
        if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> ||
                     is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
                     is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
                     is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
                     is_same_v<ELayout, ctc::NDHWGK>)
        {
            const index_t K = arg.e_g_n_k_wos_lengths_[2];

            if(!(K % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check Gridwise GEMM
        return GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                         arg.b_grid_desc_bk0_n_bk1_,
                                         arg.ds_grid_desc_m_n_,
                                         arg.e_grid_desc_m_n_,
                                         arg.block_2_etile_map_);
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(
        const void* p_a,
        const void* p_b,
        const std::array<const void*, NumDTensor>& p_ds,
        void* p_e,
        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
        const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
        const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
        const std::array<index_t, NDimSpatial>& input_left_pads,
        const std::array<index_t, NDimSpatial>& input_right_pads,
        const AElementwiseOperation& a_element_op,
        const BElementwiseOperation& b_element_op,
        const CDEElementwiseOperation& cde_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_ds,
                        p_e,
                        a_g_n_c_wis_lengths,
                        a_g_n_c_wis_strides,
                        b_g_k_c_xs_lengths,
                        b_g_k_c_xs_strides,
                        ds_g_n_k_wos_lengths,
                        ds_g_n_k_wos_strides,
                        e_g_n_k_wos_lengths,
                        e_g_n_k_wos_strides,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
                        input_right_pads,
                        1,
                        1,
                        a_element_op,
                        b_element_op,
                        cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const void* p_a,
        const void* p_b,
        const std::array<const void*, NumDTensor>& p_ds,
        void* p_e,
        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
        const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
        const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_strides,
        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
        const std::array<index_t, NDimSpatial>& input_left_pads,
        const std::array<index_t, NDimSpatial>& input_right_pads,
        const AElementwiseOperation& a_element_op,
        const BElementwiseOperation& b_element_op,
        const CDEElementwiseOperation& cde_element_op) override
    {
        return std::make_unique<Argument>(p_a,
                                          p_b,
                                          p_ds,
                                          p_e,
                                          a_g_n_c_wis_lengths,
                                          a_g_n_c_wis_strides,
                                          b_g_k_c_xs_lengths,
                                          b_g_k_c_xs_strides,
                                          ds_g_n_k_wos_lengths,
                                          ds_g_n_k_wos_strides,
                                          e_g_n_k_wos_lengths,
                                          e_g_n_k_wos_strides,
                                          conv_filter_strides,
                                          conv_filter_dilations,
                                          input_left_pads,
                                          input_right_pads,
                                          1,
                                          1,
                                          a_element_op,
                                          b_element_op,
                                          cde_element_op);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << getConvForwardSpecializationString(ConvForwardSpecialization)
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -10,46 +10,11 @@
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
template <typename GridwiseReduction,
          typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename AccDataType,
          typename AccElementwiseOperation,
          typename GridDesc_M_K>
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
                                     const GridDesc_M_K gamma_grid_desc_m_k,
                                     const GridDesc_M_K beta_grid_desc_m_k,
                                     const GridDesc_M_K y_grid_desc_m_k,
                                     index_t num_k_block_tile_iteration,
                                     AccDataType epsilon,
                                     const XDataType* const __restrict__ p_x_global,
                                     const GammaDataType* const __restrict__ p_gamma_global,
                                     const BetaDataType* const __restrict__ p_beta_global,
                                     YDataType* const __restrict__ p_y_global,
                                     const AccElementwiseOperation acc_elementwise_op)
{
    GridwiseReduction::Run(x_grid_desc_m_k,
                           gamma_grid_desc_m_k,
                           beta_grid_desc_m_k,
                           y_grid_desc_m_k,
                           num_k_block_tile_iteration,
                           epsilon,
                           p_x_global,
                           p_gamma_global,
                           p_beta_global,
                           p_y_global,
                           acc_elementwise_op);
};
} // namespace ck

namespace ck {
namespace tensor_operation {
namespace device {
@@ -58,9 +23,9 @@ namespace device {
template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename AccDataType,
          typename ComputeDataType,
          typename YDataType,
          typename AccElementwiseOperation,
          typename YElementwiseOperation,
          index_t Rank,
          index_t NumReduceDim,
          index_t BlockSize,
@@ -74,16 +39,18 @@ template <typename XDataType,
          index_t GammaSrcVectorSize,
          index_t BetaSrcVectorDim,
          index_t BetaSrcVectorSize,
          index_t YDstVectorSize>
          index_t YDstVectorSize,
          bool UseWelford = true>
struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                                            GammaDataType,
                                                            BetaDataType,
                                                            AccDataType,
                                                            ComputeDataType,
                                                            YDataType,
                                                            AccElementwiseOperation,
                                                            YElementwiseOperation,
                                                            Rank,
                                                            NumReduceDim>
{
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
    static_assert(
        ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
         (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
@@ -167,51 +134,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,

    using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));

    using GridwiseReduceLayernormGeneric =
        GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
                                                      GammaDataType,
                                                      BetaDataType,
                                                      YDataType,
                                                      AccDataType,
                                                      AccElementwiseOperation,
                                                      GridDesc_M_K,
                                                      BlockSize,
                                                      MThreadClusterSize,
                                                      KThreadClusterSize,
                                                      MThreadSliceSize,
                                                      KThreadSliceSize,
                                                      XYSrcVectorDim,
                                                      XSrcVectorSize,
                                                      GammaSrcVectorDim,
                                                      GammaSrcVectorSize,
                                                      BetaSrcVectorDim,
                                                      BetaSrcVectorSize,
                                                      XYSrcVectorDim,
                                                      YDstVectorSize,
                                                      false>;
    using GridwiseNormalizationSweepOnce =
        GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
                                                      GammaDataType,
                                                      BetaDataType,
                                                      YDataType,
                                                      AccDataType,
                                                      AccElementwiseOperation,
                                                      GridDesc_M_K,
                                                      BlockSize,
                                                      MThreadClusterSize,
                                                      KThreadClusterSize,
                                                      MThreadSliceSize,
                                                      KThreadSliceSize,
                                                      XYSrcVectorDim,
                                                      XSrcVectorSize,
                                                      GammaSrcVectorDim,
                                                      GammaSrcVectorSize,
                                                      BetaSrcVectorDim,
                                                      BetaSrcVectorSize,
                                                      XYSrcVectorDim,
                                                      YDstVectorSize,
                                                      true>;

    struct Argument : public BaseArgument
    {
        Argument(const std::vector<index_t> lengths,
@@ -220,7 +142,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                 const std::vector<index_t> betaStrides,
                 const std::vector<index_t> yStrides,
                 const std::vector<index_t> reduceDims,
                 AccElementwiseOperation acc_elementwise_op,
                 YElementwiseOperation y_elementwise_op,
                 double epsilon,
                 const XDataType* p_x,
                 const GammaDataType* p_gamma,
@@ -230,9 +152,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
              p_gamma_(p_gamma),
              p_beta_(p_beta),
              p_y_(p_y),
              acc_elementwise_op_(acc_elementwise_op)
              y_elementwise_op_(y_elementwise_op)
        {
            epsilon_ = static_cast<AccDataType>(epsilon);
            epsilon_ = static_cast<ComputeDataType>(epsilon);

            Lengths_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
            xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
@@ -265,7 +187,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
        }

        AccDataType epsilon_;
        ComputeDataType epsilon_;

        const XDataType* p_x_;
        const GammaDataType* p_gamma_;
@@ -278,7 +200,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
        std::vector<index_t> betaStrides_;
        std::vector<index_t> yStrides_;

        AccElementwiseOperation acc_elementwise_op_;
        YElementwiseOperation y_elementwise_op_;

        int blkGroupSize_;
        int numBlockTileIteration_;
@@ -295,23 +217,27 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto kernel_main = arg.isSweeponce_
                                         ? kernel_normalization<GridwiseNormalizationSweepOnce,
                                                                XDataType,
                                                                GammaDataType,
                                                                BetaDataType,
                                                                YDataType,
                                                                AccDataType,
                                                                AccElementwiseOperation,
                                                                GridDesc_M_K>
                                         : kernel_normalization<GridwiseReduceLayernormGeneric,
                                                                XDataType,
                                                                GammaDataType,
                                                                BetaDataType,
                                                                YDataType,
                                                                AccDataType,
                                                                AccElementwiseOperation,
                                                                GridDesc_M_K>;
            auto kernel_main = NormalizationKernelSelector<XDataType,
                                                           GammaDataType,
                                                           BetaDataType,
                                                           YDataType,
                                                           ComputeDataType,
                                                           YElementwiseOperation,
                                                           GridDesc_M_K,
                                                           BlockSize,
                                                           MThreadClusterSize,
                                                           KThreadClusterSize,
                                                           MThreadSliceSize,
                                                           KThreadSliceSize,
                                                           XYSrcVectorDim,
                                                           XSrcVectorSize,
                                                           GammaSrcVectorDim,
                                                           GammaSrcVectorSize,
                                                           BetaSrcVectorDim,
                                                           BetaSrcVectorSize,
                                                           XYSrcVectorDim,
                                                           YDstVectorSize,
                                                           UseWelford>(arg.isSweeponce_);

            float avg_time = 0;
            avg_time += launch_and_time_kernel(stream_config,
@@ -329,7 +255,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                               arg.p_gamma_,
                                               arg.p_beta_,
                                               arg.p_y_,
                                               arg.acc_elementwise_op_);
                                               arg.y_elementwise_op_);

            return (avg_time);
        };
@@ -429,7 +355,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                        void* p_y,
                        void* p_saveMean,
                        void* p_saveInvVar,
                        AccElementwiseOperation acc_elementwise_op) override
                        YElementwiseOperation y_elementwise_op) override
    {
        // TODO
        // Optional cache of the intermediate results (mean and InvVariance) during the
@@ -443,7 +369,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                          betaStrides,
                                          yStrides,
                                          reduceDims,
                                          acc_elementwise_op,
                                          y_elementwise_op,
                                          epsilon,
                                          static_cast<const XDataType*>(p_x),
                                          static_cast<const GammaDataType*>(p_gamma),
@@ -462,8 +388,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,

        // clang-format off
        str << "DeviceNormalizationImpl<" << BlockSize << ",";
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
        str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
        str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
        str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">";
        // clang-format on

@@ -49,6 +49,14 @@ struct Add
        y = x0 + x1;
    };

    template <>
    __host__ __device__ constexpr void
    operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
    {
        const float x1_tmp = ck::type_convert<float>(x1);
        y                  = x0 + x1_tmp;
    }

    template <>
    __host__ __device__ constexpr void
    operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
@@ -67,6 +75,30 @@ struct Add
    };
};

struct ScaleAdd
{
    __host__ __device__ ScaleAdd(float scale) : scale_(scale) {}

    template <typename Y, typename X0, typename X1>
    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;

    template <>
    __host__ __device__ void
    operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
    {
        y = scale_ * x0 + ck::type_convert<float>(x1);
    };

    template <>
    __host__ __device__ void
    operator()<float, float, bhalf_t>(float& y, const float& x0, const bhalf_t& x1) const
    {
        y = scale_ * x0 + ck::type_convert<float>(x1);
    };

    float scale_;
};

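// Illustrative usage sketch (not part of the original diff): ScaleAdd computes
// y = scale * x0 + x1, with x1 converted to float first.
//   ck::tensor_operation::element_wise::ScaleAdd op{0.5f};
//   float y = 0.f;
//   op(y, 3.0f, ck::type_convert<ck::half_t>(1.0f)); // y == 0.5f * 3.0f + 1.0f == 2.5f
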
struct Subtract
{
    template <typename T>
@@ -118,6 +150,13 @@ struct Bilinear
    template <typename Y, typename X0, typename X1>
    __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;

    template <>
    __host__ __device__ constexpr void
    operator()<double, double, double>(double& y, const double& x0, const double& x1) const
    {
        y = alpha_ * x0 + beta_ * x1;
    };

    template <>
    __host__ __device__ constexpr void
    operator()<float, float, float>(float& y, const float& x0, const float& x1) const

@@ -95,6 +95,12 @@ struct Scale
        y = scale_ * x;
    };

    template <>
    __host__ __device__ void operator()<double, double>(double& y, const double& x) const
    {
        y = scale_ * x;
    };

    float scale_;
};


@@ -879,14 +879,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
            }
        } // end gemm1

        // workaround compiler issue; see ck/ck.hpp
        if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 &&
                     is_same_v<FloatAB, bhalf_t> && MPerBlock == 256 && NPerBlock == 128 &&
                     Gemm1NPerBlock == 128)
        {
            __builtin_amdgcn_sched_barrier(0);
        }

        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
            gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
        constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -574,4 +574,546 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
    }
};

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_B_K0_M_K1,
          typename BGridDesc_B_K0_N_K1,
          typename CGridDesc_M_N,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t K1Value,
          index_t M1PerThreadM111,
          index_t N1PerThreadN111,
          index_t KPerThread,
          typename M11N11ThreadClusterM110Xs,
          typename M11N11ThreadClusterN110Xs,
          typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
          typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemmDl_bkm_bkn_mn_v1r3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // TODO: change this. I think it needs multi-dimensional alignment
        constexpr auto max_lds_align = K1;

        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_b_k0_m_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_b_k0_n_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_b_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_block_desc_b_k0_n_k1.GetElementSpaceSize(), max_lds_align);

        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
    }

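    // Worked example (illustrative, not part of the original source): with
    // K0PerBlock=8, MPerBlock=128, NPerBlock=128, K1=2 and FloatAB=half_t (2 bytes),
    // each buffer holds 1*8*128*2 = 2048 elements, so the double-buffered A+B LDS is
    // 2 * (2048 + 2048) * 2 bytes = 16 KiB (before any alignment rounding).
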
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_B_K0_M_K1& a_grid_desc_b_k0_m_k1,
                  const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1,
                  const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M      = a_grid_desc_b_k0_m_k1.GetLength(I2);
        const auto N      = b_grid_desc_b_k0_n_k1.GetLength(I2);
        const auto K0     = a_grid_desc_b_k0_m_k1.GetLength(I1);
        const auto KBatch = a_grid_desc_b_k0_m_k1.GetLength(I0);

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)

        return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
                K0 == b_grid_desc_b_k0_n_k1.GetLength(I1) &&
                K1 == a_grid_desc_b_k0_m_k1.GetLength(I3) &&
                K1 == b_grid_desc_b_k0_n_k1.GetLength(I3)) &&
               KBatch == b_grid_desc_b_k0_n_k1.GetLength(I0) &&
               (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
    {
        const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;

        return has_main_k_block_loop;
    }

    __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
    {
        const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;

        return has_double_tail_k_block_loop;
    }

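    // Worked example (illustrative, not from the original source): with K0PerBlock=8,
    //   K0=32: (32+8)/(2*8) = 2 > 1 -> main loop runs;  (32/8)%2 == 0 -> double-tail epilogue
    //   K0=16: (16+8)/(2*8) = 1     -> no main loop;    (16/8)%2 == 0 -> double-tail epilogue
    //   K0=24: (24+8)/(2*8) = 2 > 1 -> main loop runs;  (24/8)%2 == 1 -> single-tail epilogue
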
__host__ __device__ static constexpr auto
|
||||
MakeAGridDescriptor_B_K0_M0_M1_K1(const AGridDesc_B_K0_M_K1& a_grid_desc_b_k0_m_k1)
|
||||
{
|
||||
const auto KBatch = a_grid_desc_b_k0_m_k1.GetLength(I0);
|
||||
const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1);
|
||||
const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2);
|
||||
|
||||
const auto M1 = Number<MPerBlock>{};
|
||||
const auto M0 = M / M1;
|
||||
|
||||
const auto a_grid_desc_b_k0_m0_m1_k1 = transform_tensor_descriptor(
|
||||
a_grid_desc_b_k0_m_k1,
|
||||
make_tuple(make_pass_through_transform(KBatch),
|
||||
make_pass_through_transform(K0),
|
||||
make_unmerge_transform(make_tuple(M0, M1)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
|
||||
|
||||
return a_grid_desc_b_k0_m0_m1_k1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBGridDescriptor_B_K0_N0_N1_K1(const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1)
|
||||
{
|
||||
const auto KBatch = b_grid_desc_b_k0_n_k1.GetLength(I0);
|
||||
const auto K0 = b_grid_desc_b_k0_n_k1.GetLength(I1);
|
||||
const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2);
|
||||
|
||||
const auto N1 = Number<NPerBlock>{};
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto b_grid_desc_b_k0_n0_n1_k1 = transform_tensor_descriptor(
|
||||
b_grid_desc_b_k0_n_k1,
|
||||
make_tuple(make_pass_through_transform(KBatch),
|
||||
make_pass_through_transform(K0),
|
||||
make_unmerge_transform(make_tuple(N0, N1)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
|
||||
|
||||
return b_grid_desc_b_k0_n0_n1_k1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
constexpr auto M11 =
|
||||
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
|
||||
M1PerThreadM111>{};
|
||||
constexpr auto N11 =
|
||||
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
|
||||
N1PerThreadN111>{};
|
||||
|
||||
constexpr auto M10 = M1 / M11;
|
||||
constexpr auto N10 = N1 / N11;
|
||||
|
||||
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
|
||||
make_unmerge_transform(make_tuple(N0, N10, N11))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
|
||||
|
||||
return c_grid_desc_m0_m10_m11_n0_n10_n11;
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
|
||||
const CGridDesc_M_N& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch)
|
||||
{
|
||||
return BlockToCTileMap_KSplit_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
c_m_n_grid_desc, M01, N01, KBatch);
|
||||
}
|
||||
|
||||
using AGridDesc_B_K0_M0_M1_K1 =
|
||||
decltype(MakeAGridDescriptor_B_K0_M0_M1_K1(AGridDesc_B_K0_M_K1{}));
|
||||
using BGridDesc_B_K0_N0_N1_K1 =
|
||||
decltype(MakeBGridDescriptor_B_K0_N0_N1_K1(BGridDesc_B_K0_N_K1{}));
|
||||
using CGridDesc_M0_M10_M11_N0_N10_N11 =
|
||||
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
|
||||
|

    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void
        Run(const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            FloatAB* __restrict__ p_shared_block,
            const AGridDesc_B_K0_M0_M1_K1& a_grid_desc_b_k0_m0_m1_k1,
            const BGridDesc_B_K0_N0_N1_K1& b_grid_desc_b_k0_n0_n1_k1,
            const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
            const CBlockClusterAdaptor& c_block_cluster_adaptor,
            integral_constant<bool, HasMainKBlockLoop>,
            integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_b_k0_m0_m1_k1.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_b_k0_n0_n1_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());

        // divide block work by [M, N]
        const auto block_work_idx =
            c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t k_batch_id = block_work_idx[I0];

        if(!c_block_cluster_adaptor.ValidCTileIndex(
               make_tuple(block_work_idx[I1], block_work_idx[I2]),
               make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
                          c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
        {
            return;
        }

        // HACK: this forces m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);

        // TODO: change this. I think it needs multi-dimensional alignment
        constexpr auto max_lds_align = K1;

        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(I1, Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(I1, Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // A matrix in LDS memory, for blockwise GEMM
        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, for blockwise GEMM
        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

        static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
                          a_k0_m_k1_block_desc.GetElementSpaceSize() &&
                      b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
                          b_k0_n_k1_block_desc.GetElementSpaceSize() &&
                      "wrong!");
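        // Note: two descriptor families alias the same LDS bytes. The 5-D *_b_k0_*_k1 views
        // are the destinations of the blockwise copies, while the 3-D *_k0_*_k1 views are what
        // the blockwise GEMM consumes; the static_assert above pins both views to an identical
        // element-space size so the aliasing stays valid.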
        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
            InMemoryDataOperationEnum::Set,
            Sequence<1, K0PerBlock, 1, MPerBlock, K1.value>,
            ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
            ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            remove_reference_t<decltype(a_grid_desc_b_k0_m0_m1_k1)>,
            decltype(a_block_desc_b_k0_m0_m1_k1),
            ABlockTransferSrcAccessOrder,
            Sequence<0, 1, 2, 3, 4>,
            ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
            ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
            ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
            Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
            false,
            true>(a_grid_desc_b_k0_m0_m1_k1,
                  make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0, 0),
                  a_block_desc_b_k0_m0_m1_k1,
                  make_multi_index(0, 0, 0, 0, 0));

        // B matrix blockwise copy
        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
            InMemoryDataOperationEnum::Set,
            Sequence<1, K0PerBlock, 1, NPerBlock, K1.value>,
            BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
            BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            remove_reference_t<decltype(b_grid_desc_b_k0_n0_n1_k1)>,
            decltype(b_block_desc_b_k0_n0_n1_k1),
            BBlockTransferSrcAccessOrder,
            Sequence<0, 1, 2, 3, 4>,
            BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
            BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
            BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
            Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
            false,
            true>(b_grid_desc_b_k0_n0_n1_k1,
                  make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0, 0),
                  b_block_desc_b_k0_n0_n1_k1,
                  make_multi_index(0, 0, 0, 0, 0));
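        // The Src/DstVectorTensorLengths arguments set how many contiguous elements each thread
        // moves per memory instruction; as a rule of thumb (a general property of this transfer
        // class, not a guarantee of any specific configuration), aligning the vector dimension
        // with the fastest-varying global-memory dimension keeps these copies coalesced.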
        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[K0PerBlock, MPerBlock] is in LDS
        // b_mtx[K0PerBlock, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        const auto blockwise_gemm =
            BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
                BlockSize,
                FloatAB,
                FloatAB,
                FloatAcc,
                decltype(a_k0_m_k1_block_desc),
                decltype(b_k0_n_k1_block_desc),
                M1PerThreadM111,
                N1PerThreadN111,
                KPerThread,
                M11N11ThreadClusterM110Xs,
                M11N11ThreadClusterN110Xs,
                M1PerThreadM111,
                N1PerThreadN111>{};
        constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
            decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();

        constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
            sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block_double = p_shared_block;
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

        // register allocation for output
        auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
            c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());

        // Initialize C
        c_thread_buf.Clear();

        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0);

        auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
        auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

        auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double + a_block_aligned_space_size,
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
        auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double + b_block_aligned_space_size,
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
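        // LDS layout implied by the pointer arithmetic above (sizes are the aligned
        // element-space sizes computed earlier, not assumptions):
        //   p_shared_block + 0                                                    -> A even buffer
        //   p_shared_block + a_block_aligned_space_size                           -> A odd buffer
        //   p_shared_block + 2 * a_block_aligned_space_size                       -> B even buffer
        //   p_shared_block + 2 * a_block_aligned_space_size
        //                  + b_block_aligned_space_size                           -> B odd buffer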
        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);

            a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_even_buf);
            b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_even_buf);
        }

        if constexpr(HasMainKBlockLoop)
        {
            const auto K0 = a_grid_desc_b_k0_m0_m1_k1.GetLength(I1);

            index_t k_block_data_begin = 0;

            // LDS double buffer: main body
            // use Do-While loop instead of For loop to simplify control flow
            do
            {
                // even iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1,
                                                    a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1,
                                                    b_block_slice_copy_step);

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
                b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);

                block_sync_lds();

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
                                   a_block_even_buf,
                                   b_block_even_buf,
                                   c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_odd_buf);
                b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_odd_buf);

                // odd iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1,
                                                    a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1,
                                                    b_block_slice_copy_step);

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
                b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);

                block_sync_lds();

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(
                    c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_even_buf);
                b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_even_buf);

                k_block_data_begin += 2 * K0PerBlock;
            } while(k_block_data_begin < K0 - 2 * K0PerBlock);
        }
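        // Loop-bound sketch: each do-while trip consumes 2 * K0PerBlock of K0 (one even plus
        // one odd stage), so the loop exits with data for the final stages already prefetched;
        // the tail below then runs the remaining two GEMM stages (or one, when
        // HasDoubleTailKBlockLoop is false).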
        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1, a_block_slice_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1, b_block_slice_copy_step);

            block_sync_lds();

            // LDS double buffer: load last data from device mem
            a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf);

            // LDS double buffer: GEMM on 2nd-last data
            blockwise_gemm.Run(
                c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);

            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_odd_buf);
            b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_odd_buf);

            block_sync_lds();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(
                c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
        }
        else // if 1 iteration is left
        {
            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(
                c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
        }
        // output: register to global memory
        {
            constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
                make_naive_tensor_descriptor_packed(
                    make_tuple(I1,
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
                               I1,
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));

            const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
                blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
                    get_thread_local_1d_id());

            ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
                decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
                ck::tensor_operation::element_wise::PassThrough,
                Sequence<1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I0],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I1],
                         1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I2],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
                      make_multi_index(m_block_data_idx_on_grid,
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
                                       n_block_data_idx_on_grid,
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
                      ck::tensor_operation::element_wise::PassThrough{}}
                .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
                     make_tuple(I0, I0, I0, I0, I0, I0),
                     c_thread_buf,
                     c_grid_desc_m0_m10_m11_n0_n10_n11,
                     c_grid_buf);
        }
    }
};

} // namespace ck
@@ -0,0 +1,937 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseOp,
          typename ADataType,
          typename BDataType,
          typename DsPointer,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          typename Block2CTileMap,
          typename ComputePtrOffsetOfBatch,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle(
            const ADataType* __restrict__ p_a_grid,
            const BDataType* __restrict__ p_b_grid,
            DsPointer p_ds_grid,
            EDataType* __restrict__ p_e_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const index_t batch_count,
            const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock_,
            const Block2CTileMap block_2_ctile_map,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
    // offset base pointer for each work-group
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

    const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

    __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];

    DsPointer p_ds_grid_grp;

    static constexpr index_t NumDTensor =
        DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();

    static_for<0, NumDTensor, 1>{}(
        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });

    GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                p_b_grid + b_batch_offset,
                                                p_ds_grid_grp,
                                                p_e_grid + e_batch_offset,
                                                p_shared,
                                                a_grid_desc_k0_m_k1,
                                                b_grid_desc_k0_n_k1,
                                                ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                                e_grid_desc_mblock_mperblock_nblock_nperblock_,
                                                a_element_op,
                                                b_element_op,
                                                cde_element_op,
                                                block_2_ctile_map);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_ds_grid;
    ignore = p_e_grid;
    ignore = batch_count;
    ignore = a_grid_desc_k0_m_k1;
    ignore = b_grid_desc_k0_n_k1;
    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
    ignore = compute_ptr_offset_of_batch;
    ignore = block_2_ctile_map;
#endif
}
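// Worked example of the batch decomposition above (illustrative numbers, not from this file):
// with a 960-block grid and batch_count = 8, num_blocks_per_batch = 120, so block 250 computes
// g_idx = 250 / 120 = 2 and reads A/B/Ds/E through the pointer offsets of batch 2.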

template <typename GridwiseOp,
          typename ADataType,
          typename BDataType,
          typename DsPointer,
          typename EDataType,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename ComputePtrOffsetOfBatch,
          typename Block2CTileMap,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_contraction_multiple_d_wmma_cshuffle(
            const ADataType* __restrict__ p_a_grid,
            const BDataType* __restrict__ p_b_grid,
            DsPointer p_ds_grid,
            EDataType* __restrict__ p_e_grid,
            const index_t batch_count,
            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
            const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
            const Block2CTileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
    __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];

    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

    const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

    static constexpr index_t NumDTensor =
        DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();

    DsPointer p_ds_grid_grp;

    static_for<0, NumDTensor, 1>{}(
        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });

    GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                p_b_grid + b_batch_offset,
                                                p_ds_grid_grp,
                                                p_e_grid + e_batch_offset,
                                                p_shared,
                                                a_grid_desc_k0_m_k1,
                                                b_grid_desc_k0_n_k1,
                                                ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                                e_grid_desc_mblock_mperblock_nblock_nperblock,
                                                a_element_op,
                                                b_element_op,
                                                cde_element_op,
                                                block_2_etile_map);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_ds_grid;
    ignore = p_e_grid;
    ignore = batch_count;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
    ignore = a_grid_desc_k0_m_k1;
    ignore = b_grid_desc_k0_n_k1;
    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = block_2_etile_map;
    ignore = compute_ptr_offset_of_batch;
#endif
}

template <typename GridwiseOp,
          typename ADataType,
          typename BDataType,
          typename DsPointer,
          typename EDataType,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename Block2CTileMap,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_mupltipe_d_wmma_cshuffle(
            const ADataType* __restrict__ p_a_grid,
            const BDataType* __restrict__ p_b_grid,
            DsPointer p_ds_grid,
            EDataType* __restrict__ p_e_grid,
            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
            const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
    __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];

    GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
                                                p_b_grid,
                                                p_ds_grid,
                                                p_e_grid,
                                                p_shared,
                                                a_grid_desc_k0_m_k1,
                                                b_grid_desc_k0_n_k1,
                                                ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                                e_grid_desc_mblock_mperblock_nblock_nperblock,
                                                a_element_op,
                                                b_element_op,
                                                cde_element_op,
                                                block_2_ctile_map);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_ds_grid;
    ignore = p_e_grid;
    ignore = a_grid_desc_k0_m_k1;
    ignore = b_grid_desc_k0_n_k1;
    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
    ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx1100__))
}

template < // DataType Family
    typename ADataType,
    typename BDataType,
    typename AccDataType,
    typename CShuffleDataType,
    typename DsDataType,
    typename EDataType,
    // InMemory Data Descriptor
    typename AGridDesc_K0_M_K1,
    typename BGridDesc_K0_N_K1,
    typename DsGridDesc_M_N,
    typename EGridDesc_M_N,
    // ElementwiseOp Family
    typename AElementwiseOperation,
    typename BElementwiseOperation,
    typename CDEElementwiseOperation,
    InMemoryDataOperationEnum EGlobalMemoryDataOperation,
    // Tiling Family
    index_t MPerBlock,
    index_t NPerBlock,
    index_t K0PerBlock,
    index_t MPerWmma,
    index_t NPerWmma,
    index_t K1Value,
    index_t MRepeat,
    index_t NRepeat,
    // ThreadCluster Family
    index_t BlockSize,
    typename ABlockTransferThreadClusterLengths_K0_M_K1,
    typename ABlockTransferThreadClusterArrangeOrder,
    typename ABlockTransferSrcAccessOrder,
    index_t ABlockTransferSrcVectorDim,
    index_t ABlockTransferSrcScalarPerVector,
    index_t ABlockTransferDstScalarPerVector_K1,
    bool AThreadTransferSrcResetCoordinateAfterRun,
    bool ABlockLdsExtraM,
    typename BBlockTransferThreadClusterLengths_K0_N_K1,
    typename BBlockTransferThreadClusterArrangeOrder,
    typename BBlockTransferSrcAccessOrder,
    index_t BBlockTransferSrcVectorDim,
    index_t BBlockTransferSrcScalarPerVector,
    index_t BBlockTransferDstScalarPerVector_K1,
    bool BThreadTransferSrcResetCoordinateAfterRun,
    bool BBlockLdsExtraN,
    index_t CShuffleMRepeatPerShuffle,
    index_t CShuffleNRepeatPerShuffle,
    typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
    index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
    index_t NumGemmKPrefetchStage = 1,
    LoopScheduler LoopSched = make_default_loop_scheduler(),
    PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
{
    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto K1 = Number<K1Value>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    using GridwiseGemmPipe = remove_cvref_t<decltype(
        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        return a_block_desc_k0perblock_mperblock_k1;
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
            }
        }();

        return b_block_desc_k0perblock_nperblock_k1;
    }
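    // When ABlockLdsExtraM / BBlockLdsExtraN are set, the descriptors above pad the K0 stride
    // from MPerBlock * K1 (resp. NPerBlock * K1) to (MPerBlock + 1) * K1. The usual rationale
    // for this kind of padding is to stagger rows across LDS banks and so reduce bank
    // conflicts, at the cost of a slightly larger LDS footprint.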
    __host__ __device__ static constexpr auto
    // *Caution: here "Repeat" means shuffle repeat
    GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
    {
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
        constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);

        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
                           I1,
                           Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));

        return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
    }

    // ck::Tuple<const D0DataType*, const D1DataType*, ...>
    static constexpr auto MakeDsGridPointer()
    {
        return generate_tuple(
            [&](auto i) {
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                return static_cast<const DDataType*>(nullptr);
            },
            Number<NumDTensor>{});
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_k0perblock_mperblock_k1 =
            GetABlockDescriptor_K0PerBlock_MPerBlock_K1();

        constexpr auto b_block_desc_k0perblock_nperblock_k1 =
            GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

        constexpr auto max_lds_align = K1;

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size_aligned * sizeof(ADataType) +
                b_block_space_size_aligned * sizeof(BDataType));
    }
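    // Size sanity check (illustrative numbers, not from this file): K0PerBlock = 8,
    // MPerBlock = 128, NPerBlock = 64, K1 = 8, fp16 A and B, no Extra padding:
    //   A: 8 * 128 * 8 = 8192 elements, B: 8 * 64 * 8 = 4096 elements,
    //   (8192 + 4096) * 2 bytes = 24 KiB of LDS requested by this function.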
    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
    template <typename Block2CTileMap>
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                  const DsGridDesc_M_N& ds_grid_desc_m_n,
                  const EGridDesc_M_N& e_grid_desc_m_n,
                  const Block2CTileMap& block_2_ctile_map)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 needs to be known at compile-time");

        static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
                          (NPerBlock % (NRepeat * NPerWmma)) == 0,
                      "Invalid tuning param!");

        const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
        const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        bool valid = true;

        static_for<0, NumDTensor, 1>{}([&](auto i) {
            valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
                              N == ds_grid_desc_m_n[i].GetLength(I1));
        });

        if(!valid)
        {
            return false;
        }

        if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
             K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
            return false;

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

        // check gridwise gemm pipeline
        const auto num_k_loop = K0 / K0PerBlock;

        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
        {
            return false;
        }

        if(!block_2_ctile_map.CheckValidity(e_grid_desc_m_n))
        {
            return false;
        }

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / (K0PerBlock * K1);

        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }
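    // Example of the arithmetic above (illustrative numbers): K = 512, K0PerBlock = 8 and
    // K1 = 8 give num_loop = 512 / 64 = 8; whether that count requires a main loop is then
    // delegated to the selected GridwiseGemmPipe implementation.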
    // E desc for destination in blockwise copy
    template <typename EGridDesc_M_N_>
    __host__ __device__ static constexpr auto
    MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N_& e_grid_desc_m_n)
    {
        const auto M = e_grid_desc_m_n.GetLength(I0);
        const auto N = e_grid_desc_m_n.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;

        const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            e_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

        return e_grid_desc_mblock_mperblock_nblock_nperblock;
    }

    // Ds desc for source in blockwise copy
    template <typename DsGridDesc_M_N_>
    __host__ __device__ static constexpr auto
    MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N_& ds_grid_desc_m_n)
    {
        return generate_tuple(
            [&](auto i) {
                return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
            },
            Number<NumDTensor>{});
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
        const EGridDesc_M_N& e_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
    {
        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
            e_grid_desc_m_n);
    }

    using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
        MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
    using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
        MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
    using DefaultBlock2CTileMap =
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(EGridDesc_M_N{}, 1, 1))>;
    using DsGridPointer = decltype(MakeDsGridPointer());

    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void Run(const ADataType* __restrict__ p_a_grid,
                               const BDataType* __restrict__ p_b_grid,
                               DsGridPointer p_ds_grid,
                               EDataType* __restrict__ p_e_grid,
                               void* __restrict__ p_shared,
                               const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                               const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                               const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                   ds_grid_desc_mblock_mperblock_nblock_nperblock,
                               const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const CDEElementwiseOperation& cde_element_op,
                               const Block2CTileMap& block_2_ctile_map)
    {
        // clang-format off
        /*******************************************************************************/
        // Memory buffer zone.
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i],
                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});
        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        /*******************************************************************************/
        // BlockIdx.x -> [BlockId.m, BlockId.n]
        const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        { return; }

        // Store BlockId into SGPR
        const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
        const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        /*******************************************************************************/
        // BlockLevel, A/B Matrix ThreadMapping in LDS, as Destination of BlockWise_Copy
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
        constexpr auto max_lds_align = K1;
        constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
        constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
        // A matrix blockwise copy
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
            /* typename SrcElementwiseOperation, */ AElementwiseOperation,
            /* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
            /* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
            /* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
            /* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
            /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
            /* typename SrcData, */ ADataType,
            /* typename DstData, */ ADataType,
            /* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1),
            /* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1),
            /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
            /* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
            /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
            /* index_t DstVectorDim, */ 2,
            /* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector,
            /* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1,
            /* index_t SrcScalarStrideInVector, */ 1,
            /* index_t DstScalarStrideInVector, */ 1,
            /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
            /* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
                a_grid_desc_k0_m_k1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_k0perblock_mperblock_k1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<K0PerBlock, NPerBlock, K1>,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                BDataType,
                                                BDataType,
                                                decltype(b_grid_desc_k0_n_k1),
                                                decltype(b_block_desc_k0perblock_nperblock_k1),
                                                BBlockTransferSrcAccessOrder,
                                                Sequence<0, 1, 2>,
                                                BBlockTransferSrcVectorDim,
                                                2,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                1,
                                                1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                true>(
                b_grid_desc_k0_n_k1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_k0perblock_nperblock_k1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        /*******************************************************************************/
        // GEMM
        constexpr auto WmmaK = 16;
        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
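        // KPack sketch: integer_least_multiple rounds K1 up to a multiple of WmmaK, so for an
        // assumed K1 = 8 this yields KPack = 16, i.e. the K width of one WMMA instruction;
        // for K1 = 16 it stays 16.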
        auto blockwise_gemm =
            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
                                                                       ADataType,
                                                                       BDataType,
                                                                       AccDataType,
                                                                       decltype(a_block_desc_k0perblock_mperblock_k1),
                                                                       decltype(b_block_desc_k0perblock_nperblock_k1),
                                                                       MPerWmma,
                                                                       NPerWmma,
                                                                       MRepeat,
                                                                       NRepeat,
                                                                       KPack>{};

        // Prepare Register for C matrix
        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        /*******************************************************************************/
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
        // LDS allocation for A and B: be careful of alignment
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<ADataType*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<BDataType*>(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());

        // Shift Per SUB_K
        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

        // gridwise GEMM pipeline
        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
                                                          a_block_desc_k0perblock_mperblock_k1,
                                                          a_blockwise_copy,
                                                          a_grid_buf,
                                                          a_block_buf,
                                                          a_block_slice_copy_step,
                                                          b_grid_desc_k0_n_k1,
                                                          b_block_desc_k0perblock_nperblock_k1,
                                                          b_blockwise_copy,
                                                          b_grid_buf,
                                                          b_block_buf,
                                                          b_block_slice_copy_step,
                                                          blockwise_gemm,
                                                          c_thread_buf,
                                                          K0BlockMainLoop);
        /*******************************************************************************/
        // write out to C, implement shuffle
        {
            constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

            // This API provides all the dimensions (sizes) you need
            constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
                blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

            constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
            constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
            constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4);
            constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5);
            constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6);

            // LDS descriptor, shuffle and write out in MRepeat x NRepeat times
            constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
                GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared),
                c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());

            constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
                c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
                make_tuple(
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
                        MWave, // MWave
                        MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
                        MAccVgprs)),
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
                        NWave, // NWave
                        NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
                    make_tuple(Sequence<0, 1, 2, 3>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   CShuffleDataType,
                                                   decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                                                   decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Sequence<CShuffleMRepeatPerShuffle,
                                                            I1,
                                                            I1,
                                                            CShuffleNRepeatPerShuffle,
                                                            I1,
                                                            I1,
                                                            MAccVgprs>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6>,
                                                   6,
                                                   1, // vector write pixel
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>{
                    c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                    make_multi_index(0,
                                     m_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     0,
                                     n_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3]),
                    ck::tensor_operation::element_wise::PassThrough{}};

            // tuple of reference to C/Ds tensor descriptors
            const auto c_ds_desc_refs = concat_tuple_of_reference(
                tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
                generate_tie(
                    [&](auto i) -> const auto& // return type should be reference
                    { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                    Number<NumDTensor>{}));

            // tuple of reference to C/Ds tensor buffers
            const auto c_ds_buf_refs = concat_tuple_of_reference(
                tie(c_shuffle_block_buf),
                generate_tie(
                    [&](auto i) -> const auto& // return type should be reference
                    { return ds_grid_buf[i]; },
                    Number<NumDTensor>{}));

            // tuple of starting index of C/Ds blockwise copy
            const auto idx_c_ds_block_begin = container_concat(
                make_tuple(make_multi_index(0, 0, 0, 0)),
                generate_tuple(
                    [&](auto) {
                        return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
                    },
                    Number<NumDTensor>{}));

            // shuffle: blockwise copy C from LDS to global
            auto cde_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7<
                ThisThreadBlock, // ThreadGroup
                decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
                Tuple<EDataType>,
                decltype(c_ds_desc_refs),
                decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
                CDEElementwiseOperation, // ElementwiseOperation,
                Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOp,
                Sequence<1,
                         CShuffleMRepeatPerShuffle * MWave * MPerWmma,
                         1,
                         CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
                CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
                3, // index_t VectorDim,
                CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                sequence_merge_t<
                    Sequence<true>,
                    uniform_sequence_gen_t<NumDTensor,
                                           false>>, // bool ThreadTransferSrcResetCoordinateAfterRun,
                Sequence<false>> // bool ThreadTransferDstResetCoordinateAfterRun>
                {c_ds_desc_refs,
                 idx_c_ds_block_begin,
                 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                 make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
                 cde_element_op};

            // space filling curve for local reg & global memory
            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6>,
                                  Sequence<CShuffleMRepeatPerShuffle,
                                           1,
                                           1,
                                           CShuffleNRepeatPerShuffle,
                                           1,
                                           1,
                                           MAccVgprs>>{};

            // space filling curve for shuffled blockwise C in global mem
            constexpr auto sfc_cde_global =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMRepeatPerShuffle * MWave * MPerWmma,
                                           1,
                                           CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
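            // Each pass of the loop below is bracketed by two block_sync_lds() calls: the first
            // ensures the previous LDS contents have been fully read out before new VGPR data
            // is shuffled in, the second ensures the shuffle is visible block-wide before the
            // blockwise copy streams LDS back out to global memory.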
            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");

            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread writes its data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                                              c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // each block copies its data from LDS to global
                cde_shuffle_block_copy_lds_to_global.Run(
                    c_ds_desc_refs,
                    c_ds_buf_refs,
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    tie(e_grid_buf));

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
                    // move on Ds
                    static_for<0, NumDTensor, 1>{}([&](auto i) {
                        cde_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow(
                            c_ds_desc_refs, i + I1, cde_global_step);
                    });

                    // move on E
                    cde_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                        I0,
                        cde_global_step);
                }
            });
        }
        // clang-format on
    }
};

} // namespace ck
@@ -4,9 +4,8 @@
#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_common.hpp"

#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
@@ -19,8 +18,8 @@ template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename AccDataType,
          typename AccElementwiseOperation,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename GridDesc_M_K,
          index_t BlockSize,
          index_t MThreadClusterSize,
@@ -46,6 +45,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                      (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(XSrcVectorSize == YDstVectorSize);
    static_assert(XSrcVectorSize == GammaSrcVectorSize);
    static_assert(XSrcVectorSize == BetaSrcVectorSize);

    static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
@@ -59,19 +62,23 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using BlockwiseSumReduce = PartitionedBlockwiseReduction<AccDataType,
    using BlockwiseSumReduce = PartitionedBlockwiseReduction<ComputeDataType,
                                                             BlockSize,
                                                             ThreadClusterLengths_M_K,
                                                             ThreadClusterArrangeOrder,
                                                             reduce::Add,
                                                             true>;

    using ThreadwiseSumReduce = ThreadwiseReduction<AccDataType,
    using ThreadwiseSumReduce = ThreadwiseReduction<ComputeDataType,
                                                    ThreadReduceSrcDesc_M_K,
                                                    ThreadReduceDstDesc_M,
                                                    reduce::Add,
@@ -81,64 +88,70 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;

    static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};

    __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
                               const GridDesc_M_K& gamma_grid_desc_m_k,
                               const GridDesc_M_K& beta_grid_desc_m_k,
                               const GridDesc_M_K& y_grid_desc_m_k,
                               index_t num_k_block_tile_iteration,
                               AccDataType epsilon,
                               ComputeDataType epsilon,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
                               const AccElementwiseOperation acc_elementwise_op)
                               const YElementwiseOperation y_elementwise_op)
    {
        if constexpr(SweepOnce)
        {
            num_k_block_tile_iteration = 1;
        }

        // LDS
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
        __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            x_thread_buf;
        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            gamma_thread_buf;
        auto x_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    ComputeDataType,
                                    MThreadSliceSize * XSrcVectorSize,
                                    true>{};
            },
            Number<ThreadBufferNumber>{});

        StaticBuffer<AddressSpaceEnum::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>& beta_thread_buf = gamma_thread_buf;
        auto gamma_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    ComputeDataType,
                                    MThreadSliceSize * GammaSrcVectorSize,
                                    true>{};
            },
            Number<ThreadBufferNumber>{});

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            y_thread_buf;
        auto& beta_thread_buf = gamma_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>& x_square_thread_buf = y_thread_buf;
        auto y_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    ComputeDataType,
                                    MThreadSliceSize * YDstVectorSize,
                                    true>{};
            },
            Number<ThreadBufferNumber>{});

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
        auto& x_square_thread_buf = y_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            mean_square_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_thread_buf =
            mean_square_thread_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
            mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
|
||||
});
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
|
||||
var_thread_buf = mean_square_thread_buf;
|
||||
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
@@ -149,12 +162,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
|
||||
const auto thread_m_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_k_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
|
||||
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
|
||||
|
||||
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
|
||||
AccDataType,
|
||||
ComputeDataType,
|
||||
GridDesc_M_K,
|
||||
decltype(thread_buffer_desc_m_k),
|
||||
ThreadBufferLengths_M_K,
|
||||
@@ -166,11 +175,11 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
|
||||
x_grid_desc_m_k,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
thread_k_cluster_id * XSrcVectorSize));
|
||||
|
||||
auto threadwise_gamma_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
|
||||
AccDataType,
|
||||
ComputeDataType,
|
||||
GridDesc_M_K,
|
||||
decltype(thread_buffer_desc_m_k),
|
||||
ThreadBufferLengths_M_K,
|
||||
@@ -182,11 +191,11 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
|
||||
gamma_grid_desc_m_k,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
thread_k_cluster_id * GammaSrcVectorSize));
|
||||
|
||||
auto threadwise_beta_load =
|
||||
ThreadwiseTensorSliceTransfer_v2<BetaDataType,
|
||||
AccDataType,
|
||||
ComputeDataType,
|
||||
GridDesc_M_K,
|
||||
decltype(thread_buffer_desc_m_k),
|
||||
ThreadBufferLengths_M_K,
|
||||
@@ -198,14 +207,14 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
|
||||
beta_grid_desc_m_k,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize));
|
||||
thread_k_cluster_id * BetaSrcVectorSize));
|
||||
|
||||
auto threadwise_y_store =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
|
||||
YDataType,
|
||||
decltype(thread_buffer_desc_m_k),
|
||||
GridDesc_M_K,
|
||||
AccElementwiseOperation,
|
||||
YElementwiseOperation,
|
||||
ThreadBufferLengths_M_K,
|
||||
ThreadBufferDimAccessOrder,
|
||||
YDstVectorDim,
|
||||
@@ -216,13 +225,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
|
||||
y_grid_desc_m_k,
|
||||
make_multi_index(block_global_id * M_BlockTileSize +
|
||||
thread_m_cluster_id * MThreadSliceSize,
|
||||
thread_k_cluster_id * KThreadSliceSize),
|
||||
acc_elementwise_op);
|
||||
thread_k_cluster_id * YDstVectorSize),
|
||||
y_elementwise_op);
|
||||
|
||||
// Copy x from Cache
|
||||
// one pass: fwd, second pass: bwd
|
||||
constexpr auto thread_copy_fwd_step_m_k =
|
||||
make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
|
||||
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
|
||||
constexpr auto thread_copy_bwd_step_m_k =
|
||||
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
|
||||
|
||||
@@ -239,121 +245,260 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
|
||||
// FIXME: Should not hack the transform from deviceOP
|
||||
int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
|
||||
|
||||
index_t reducedTiles = 0;
|
||||
do
|
||||
{
|
||||
threadwise_x_load.Run(x_grid_desc_m_k,
|
||||
x_global_val_buf,
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
x_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
x_square_thread_buf(Number<offset_m_k>{}) =
|
||||
x_thread_buf(Number<offset_m_k>{}) * x_thread_buf(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseSumReduce::Reduce(x_thread_buf, mean_thread_buf);
|
||||
ThreadwiseSumReduce::Reduce(x_square_thread_buf, mean_square_thread_buf);
|
||||
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
|
||||
|
||||
++reducedTiles;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if constexpr(I > 0)
|
||||
block_sync_lds();
|
||||
|
||||
BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
|
||||
mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
|
||||
mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
|
||||
|
||||
// var(x) = E[x^2] - E[x]^2
|
||||
var_thread_buf(I) =
|
||||
mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
|
||||
mean_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
|
||||
mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
|
||||
});
|
||||
|
||||
// y = (x - E[x]) / sqrt(var[x] + epsilon)
|
||||
auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
|
||||
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
|
||||
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
|
||||
|
||||
reducedTiles = 0;
|
||||
do
|
||||
// Separate sweep once and sweep twice pipeline
|
||||
if constexpr(SweepOnce)
|
||||
{
|
||||
if constexpr(!SweepOnce)
|
||||
{
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_x_load.Run(x_grid_desc_m_k,
|
||||
x_global_val_buf,
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
x_thread_buf);
|
||||
x_thread_buf(i));
|
||||
|
||||
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
|
||||
gamma_global_val_buf,
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
gamma_thread_buf(i));
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
x_square_thread_buf(i)(Number<offset_m_k>{}) =
|
||||
x_thread_buf(i)(Number<offset_m_k>{}) *
|
||||
x_thread_buf(i)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf);
|
||||
ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf);
|
||||
|
||||
if constexpr(i != ThreadBufferNumber - 1)
|
||||
{
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
|
||||
thread_copy_fwd_step_m_k);
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if constexpr(I > 0)
|
||||
block_sync_lds();
|
||||
|
||||
BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
|
||||
mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
|
||||
mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
|
||||
|
||||
// var(x) = E[x^2] - E[x]^2
|
||||
var_thread_buf(I) =
|
||||
mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// normalize
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
|
||||
divisor;
|
||||
|
||||
// gamma & beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) *
|
||||
gamma_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_beta_load.Run(beta_grid_desc_m_k,
|
||||
beta_global_val_buf,
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
beta_thread_buf(i));
|
||||
|
||||
if constexpr(i != ThreadBufferNumber - 1)
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
|
||||
thread_copy_fwd_step_m_k);
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) +
|
||||
beta_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_y_store.Run(thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
y_thread_buf(i),
|
||||
y_grid_desc_m_k,
|
||||
y_global_val_buf);
|
||||
|
||||
if constexpr(i != ThreadBufferNumber - 1)
|
||||
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
|
||||
thread_copy_fwd_step_m_k);
|
||||
});
|
||||
} // end of sweep once
|
||||
else
|
||||
{
|
||||
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
|
||||
{
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_x_load.Run(x_grid_desc_m_k,
|
||||
x_global_val_buf,
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
x_thread_buf(i));
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
x_square_thread_buf(i)(Number<offset_m_k>{}) =
|
||||
x_thread_buf(i)(Number<offset_m_k>{}) *
|
||||
x_thread_buf(i)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf);
|
||||
ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf);
|
||||
});
|
||||
}
|
||||
|
||||
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
|
||||
gamma_global_val_buf,
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
gamma_thread_buf);
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
|
||||
if constexpr(I > 0)
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
|
||||
mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
|
||||
|
||||
// normalize
|
||||
y_thread_buf(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
|
||||
sqrt(var_thread_buf(iM) + epsilon);
|
||||
block_sync_lds();
|
||||
|
||||
// gamma
|
||||
y_thread_buf(Number<offset_m_k>{}) =
|
||||
y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
|
||||
});
|
||||
BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
|
||||
mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
|
||||
|
||||
// var(x) = E[x^2] - E[x]^2
|
||||
var_thread_buf(I) =
|
||||
mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
|
||||
});
|
||||
|
||||
threadwise_beta_load.Run(beta_grid_desc_m_k,
|
||||
beta_global_val_buf,
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
beta_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
// beta
|
||||
y_thread_buf(Number<offset_m_k>{}) =
|
||||
y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
threadwise_y_store.Run(thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
y_thread_buf,
|
||||
y_grid_desc_m_k,
|
||||
y_global_val_buf);
|
||||
auto thread_copy_tail_m_k =
|
||||
(num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
|
||||
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
|
||||
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
|
||||
|
||||
++reducedTiles;
|
||||
} while(reducedTiles < num_k_block_tile_iteration);
|
||||
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
|
||||
{
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_x_load.Run(x_grid_desc_m_k,
|
||||
x_global_val_buf,
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
x_thread_buf(i));
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
|
||||
});
|
||||
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
|
||||
gamma_global_val_buf,
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
gamma_thread_buf(i));
|
||||
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
|
||||
thread_copy_fwd_step_m_k);
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// normalize
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
|
||||
divisor;
|
||||
|
||||
// gamma
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) *
|
||||
gamma_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_beta_load.Run(beta_grid_desc_m_k,
|
||||
beta_global_val_buf,
|
||||
thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
beta_thread_buf(i));
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
|
||||
thread_copy_fwd_step_m_k);
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) +
|
||||
beta_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_y_store.Run(thread_buffer_desc_m_k,
|
||||
make_tuple(I0, I0),
|
||||
y_thread_buf(i),
|
||||
y_grid_desc_m_k,
|
||||
y_global_val_buf);
|
||||
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
|
||||
thread_copy_fwd_step_m_k);
|
||||
});
|
||||
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
|
||||
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
|
||||
2 * thread_copy_bwd_step_m_k);
|
||||
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
|
||||
2 * thread_copy_bwd_step_m_k);
|
||||
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
|
||||
2 * thread_copy_bwd_step_m_k);
|
||||
}
|
||||
} // end of sweep twice
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
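The naive-variance kernel above makes one read of x to accumulate two block-wide sums, sum(x) and sum(x*x), derives the variance as var(x) = E[x^2] - E[x]^2, and then applies y = (x - E[x]) / sqrt(var(x) + epsilon) * gamma + beta. The host-side reference below is a minimal sketch of the same arithmetic, useful for validating the kernel's output; the function name and the std::vector interface are illustrative and not part of CK.

#include <cmath>
#include <cstddef>
#include <vector>

// Reference for the "naive variance" math: one pass accumulates sum(x) and
// sum(x*x), then var(x) = E[x^2] - E[x]^2. The doubles here play the role of
// the kernel's ComputeDataType accumulators. All vectors are length K.
void naive_layernorm_reference(const std::vector<float>& x,
                               const std::vector<float>& gamma,
                               const std::vector<float>& beta,
                               std::vector<float>& y,
                               float epsilon)
{
    double sum = 0.0, sum_sq = 0.0;
    for(float v : x)
    {
        sum += v;
        sum_sq += static_cast<double>(v) * v;
    }
    const double mean        = sum / x.size();
    const double mean_square = sum_sq / x.size();
    const double var         = mean_square - mean * mean; // var(x) = E[x^2] - E[x]^2
    const double inv_std     = 1.0 / std::sqrt(var + epsilon);

    for(std::size_t i = 0; i < x.size(); ++i)
        y[i] = static_cast<float>((x[i] - mean) * inv_std) * gamma[i] + beta[i];
}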
@@ -0,0 +1,195 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp"

namespace ck {
template <typename GridwiseReduction,
typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename ComputeDataType,
typename YElementwiseOperation,
typename GridDesc_M_K>
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K beta_grid_desc_m_k,
const GridDesc_M_K y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
ComputeDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const YElementwiseOperation y_elementwise_op)
{
GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
y_grid_desc_m_k,
num_k_block_tile_iteration,
epsilon,
p_x_global,
p_gamma_global,
p_beta_global,
p_y_global,
y_elementwise_op);
};

template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename ComputeDataType,
typename YElementwiseOperation,
typename GridDesc_M_K,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcVectorDim,
index_t XSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize,
bool UseWelford>
auto NormalizationKernelSelector(bool isSweepOnce)
{
using GridwiseNormalizationGenericNaive =
GridwiseNormalizationNaiveVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
false>;
using GridwiseNormalizationSweepOnceNaive =
GridwiseNormalizationNaiveVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
true>;
using GridwiseNormalizationGenericWelford =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
false>;
using GridwiseNormalizationSweepOnceWelford =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
true>;

if constexpr(UseWelford)
{
return isSweepOnce ? kernel_normalization<GridwiseNormalizationSweepOnceWelford,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>
: kernel_normalization<GridwiseNormalizationGenericWelford,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>;
}
else
{
return isSweepOnce ? kernel_normalization<GridwiseNormalizationSweepOnceNaive,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>
: kernel_normalization<GridwiseNormalizationGenericNaive,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>;
}
}

} // namespace ck
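NormalizationKernelSelector above resolves a runtime flag (isSweepOnce) across template instantiations that all share one kernel signature, so the selector can simply return a function pointer; the algorithm choice (UseWelford) stays a compile-time template parameter because the two variance methods generate different code. The stand-alone sketch below shows the same selection pattern with placeholder names; it is not CK code.

#include <cstdio>

// Stands in for the __global__ kernel: every instantiation has the same
// signature, so all four variants share the type void(*)(int).
template <bool UseWelford, bool SweepOnce>
void normalization_impl(int n)
{
    std::printf("%s variance, %s path, n=%d\n",
                UseWelford ? "welford" : "naive",
                SweepOnce ? "sweep-once" : "generic",
                n);
}

// Mirrors the shape of NormalizationKernelSelector: compile-time algorithm
// choice, runtime pipeline choice.
template <bool UseWelford>
auto select_kernel(bool is_sweep_once)
{
    return is_sweep_once ? &normalization_impl<UseWelford, true>
                         : &normalization_impl<UseWelford, false>;
}

int main()
{
    auto kernel = select_kernel<false>(/*is_sweep_once=*/true);
    kernel(8); // prints "naive variance, sweep-once path, n=8"
}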
@@ -16,8 +16,8 @@ template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename AccElementwiseOperation,
typename ComputeDataType,
typename YElementwiseOperation,
typename GridDesc_M_K,
index_t BlockSize,
index_t MThreadClusterSize,
@@ -43,6 +43,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");

static_assert(XSrcVectorSize == YDstVectorSize);
static_assert(XSrcVectorSize == GammaSrcVectorSize);
static_assert(XSrcVectorSize == BetaSrcVectorSize);

static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);

using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
@@ -56,15 +60,19 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
ThreadwiseWelford<ComputeDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;

using BlockwiseWelford = BlockwiseWelford<AccDataType,
using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder>;
@@ -77,10 +85,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;

static constexpr auto XThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
static constexpr auto GammaThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
static constexpr auto BetaThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
static constexpr auto YThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};

__device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
int thread_k_cluster_id)
@@ -93,7 +98,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk

if(kPerBlockTail > 0)
{
static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
int thread_max_len =
(thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
int delta = thread_max_len - kPerBlockTail;
@@ -110,59 +115,41 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
const GridDesc_M_K& beta_grid_desc_m_k,
const GridDesc_M_K& y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
ComputeDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const AccElementwiseOperation acc_elementwise_op)
const YElementwiseOperation y_elementwise_op)
{
if constexpr(SweepOnce)
{
num_k_block_tile_iteration = 1;
}

auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

auto x_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
ComputeDataType,
MThreadSliceSize * XSrcVectorSize,
true>{};
},
Number<XThreadBufferNumber>{});
Number<ThreadBufferNumber>{});

auto gamma_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
ComputeDataType,
MThreadSliceSize * GammaSrcVectorSize,
true>{};
},
Number<GammaThreadBufferNumber>{});
Number<ThreadBufferNumber>{});

auto beta_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * BetaSrcVectorSize,
true>{};
},
Number<BetaThreadBufferNumber>{});
auto& beta_thread_buf = gamma_thread_buf;
auto& y_thread_buf = x_thread_buf;

auto y_thread_buf = generate_tuple(
[&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
MThreadSliceSize * YDstVectorSize,
true>{};
},
Number<YThreadBufferNumber>{});

StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
var_thread_buf;

const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
@@ -173,12 +160,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];

using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
@@ -194,7 +177,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk

auto threadwise_gamma_load =
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
AccDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
@@ -210,7 +193,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk

auto threadwise_beta_load =
ThreadwiseTensorSliceTransfer_v2<BetaDataType,
AccDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
@@ -225,11 +208,11 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
thread_k_cluster_id * BetaSrcVectorSize));

auto threadwise_y_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
YDataType,
decltype(thread_buffer_desc_m_k),
GridDesc_M_K,
AccElementwiseOperation,
YElementwiseOperation,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
YDstVectorDim,
@@ -241,7 +224,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * YDstVectorSize),
acc_elementwise_op);
y_elementwise_op);

constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
constexpr auto thread_copy_bwd_step_m_k =
@@ -260,67 +243,47 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);

static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
});

for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
// Separate sweep once and sweep twice pipeline
if constexpr(SweepOnce)
{
static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
});
}

static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();

int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
});

auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k;

threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);

for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
if constexpr(!SweepOnce)
{
static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
});
}

static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) {
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf(i));

threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);

if constexpr(i != ThreadBufferNumber - 1)
{
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
thread_copy_fwd_step_m_k);
}
});

static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();

int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
});

static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
@@ -330,7 +293,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor;

// gamma
// gamma & beta
y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) *
gamma_thread_buf(iK0)(Number<offset_m_k>{});
@@ -338,18 +301,20 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
});
});

static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) {
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
beta_thread_buf(i));
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
thread_copy_fwd_step_m_k);

if constexpr(i != ThreadBufferNumber - 1)
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});

static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
@@ -362,22 +327,134 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
});
});

static_for<0, YThreadBufferNumber, 1>{}([&](auto i) {
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
y_thread_buf(i),
y_grid_desc_m_k,
y_global_val_buf);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);

if constexpr(i != ThreadBufferNumber - 1)
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});
} // end of sweep once
else
{
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
});
}

static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();

int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
});

threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
}
auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;

threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);

for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf(i));
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
});

static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf(i));

threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});

static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

// normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor;

// gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) *
gamma_thread_buf(iK0)(Number<offset_m_k>{});
});
});
});

static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_beta_load.Run(beta_grid_desc_m_k,
beta_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
beta_thread_buf(i));
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});

static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

// beta
y_thread_buf(iK0)(Number<offset_m_k>{}) =
y_thread_buf(iK0)(Number<offset_m_k>{}) +
beta_thread_buf(iK0)(Number<offset_m_k>{});
});
});
});

static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
y_thread_buf(i),
y_grid_desc_m_k,
y_global_val_buf);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
thread_copy_fwd_step_m_k);
});

threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k);
threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
2 * thread_copy_bwd_step_m_k);
}
} // end of sweep twice
}
};
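For reference, the recurrence that ThreadwiseWelford and BlockwiseWelford implement above in tiled, vectorized form is the standard single-pass Welford update, which avoids the cancellation that the naive E[x^2] - E[x]^2 formula can suffer. A minimal stand-alone sketch (not the CK implementation; the merge step corresponds to what the blockwise stage does for per-thread partials):

// Single-pass, numerically stable running mean/variance (Welford).
struct Welford
{
    double mean  = 0.0;
    double m2    = 0.0; // sum of squared deviations from the running mean
    int    count = 0;

    void update(double x)
    {
        ++count;
        const double delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean); // note: uses the *updated* mean
    }

    // Combine two partial results (Chan et al. parallel formula), as the
    // blockwise reduction does across threads.
    static Welford merge(const Welford& a, const Welford& b)
    {
        Welford r;
        r.count = a.count + b.count;
        if(r.count == 0)
            return r;
        const double delta = b.mean - a.mean;
        r.mean = a.mean + delta * b.count / r.count;
        r.m2   = a.m2 + b.m2 +
               delta * delta * (static_cast<double>(a.count) * b.count) / r.count;
        return r;
    }

    double variance() const { return count > 0 ? m2 / count : 0.0; } // population variance
};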
@@ -185,9 +185,7 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
},
Number<NumEmbeddings>{});
auto out_data_refs = generate_tie(
[&](auto output_index_) -> auto& {
return acc_thread_buf(Number<register_offset>{});
},
[&](auto) -> auto& { return acc_thread_buf(Number<register_offset>{}); },
Number<1>{});
unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
});
@@ -83,6 +83,11 @@ static inline __host__ bool isnan(int4_t x)
};
#endif

static inline __host__ half_t sqrt(half_t x)
{
return static_cast<half_t>(std::sqrt(static_cast<float>(x)));
};

static inline __host__ float sqrt(float x) { return std::sqrt(x); };

static inline __host__ double sqrt(double x) { return std::sqrt(x); };
@@ -158,9 +163,14 @@ static inline __device__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00;
};

static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
static inline __device__ half_t sqrt(half_t x)
{
return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};

static inline __device__ double sqrt(double x) { return ::sqrt(x); };
static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };

static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };

} // namespace math
} // namespace ck
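The hunk above moves the device-side sqrt overloads from libm calls to the __builtin_amdgcn_sqrt* intrinsics and adds half_t overloads that promote to float, take the square root, and cast back. The promote-compute-demote pattern in isolation, assuming a compiler that provides _Float16 (CK's half_t wraps a 16-bit float type):

#include <cmath>

using half_like = _Float16; // assumption: any 16-bit float type with float conversions

// No 16-bit sqrt exists, so compute at float precision and narrow the result,
// exactly as the half_t overloads above do.
inline half_like sqrt_via_float(half_like x)
{
    return static_cast<half_like>(std::sqrt(static_cast<float>(x)));
}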
@@ -89,8 +89,10 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;

template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
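The new AddMultiply and ScaleAdd aliases expose additional fused element-wise operators to client code. As a rough illustration of the shape such a functor takes (a sketch, not the actual CK definition of ScaleAdd), it is a callable that the store stage applies once per output element:

// Hypothetical ScaleAdd-style functor: y = scale * x0 + x1, computed in float.
struct ScaleAddSketch
{
    float scale_;

    template <typename Y, typename X0, typename X1>
    void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        y = static_cast<Y>(scale_ * static_cast<float>(x0) + static_cast<float>(x1));
    }
};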
Some files were not shown because too many files have changed in this diff