mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Merge branch 'feature/cond-add-splitkv' into feature/fmha-fwd-appendkv
This commit is contained in:
@@ -23,20 +23,7 @@ trigger:
|
||||
- Jenkinsfile
|
||||
- LICENSE
|
||||
|
||||
pr:
|
||||
autoCancel: true
|
||||
branches:
|
||||
include:
|
||||
- develop
|
||||
paths:
|
||||
exclude:
|
||||
- .github
|
||||
- docs
|
||||
- '.*.y*ml'
|
||||
- '*.md'
|
||||
- Jenkinsfile
|
||||
- LICENSE
|
||||
drafts: false
|
||||
pr: none
|
||||
|
||||
jobs:
|
||||
- template: ${{ variables.CI_COMPONENT_PATH }}/composable_kernel.yml@pipelines_repo
|
||||
|
||||
0
.pre-commit-config.yaml
Normal file → Executable file
0
.pre-commit-config.yaml
Normal file → Executable file
@@ -117,7 +117,7 @@ else()
|
||||
add_definitions(-DPROFILER_ONLY)
|
||||
set(GPU_TARGETS "" CACHE STRING "" FORCE)
|
||||
if(GPU_TARGETS)
|
||||
message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11")
|
||||
message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12")
|
||||
endif()
|
||||
if(GPU_ARCH MATCHES "gfx90")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a")
|
||||
@@ -127,8 +127,10 @@ else()
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030")
|
||||
elseif(GPU_ARCH MATCHES "gfx11")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102")
|
||||
elseif(GPU_ARCH MATCHES "gfx12")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201")
|
||||
else()
|
||||
message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11")
|
||||
message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12")
|
||||
endif()
|
||||
set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
|
||||
endif()
|
||||
|
||||
10
Dockerfile
10
Dockerfile
@@ -23,11 +23,11 @@ RUN if [ "$ROCMVERSION" != "6.2" ]; then \
|
||||
wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \
|
||||
sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \
|
||||
sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \
|
||||
elif [ "$ROCMVERSION" = "6.2" ] && [ "$compiler_version" = "rc2" ]; then \
|
||||
sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.1-20.04-1_all.deb --no-check-certificate" && \
|
||||
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.1-20.04-1_all.deb && \
|
||||
sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.1 rel-48 > /etc/apt/sources.list.d/rocm-build.list' && \
|
||||
amdgpu-repo --amdgpu-build=1736298; \
|
||||
elif [ "$ROCMVERSION" = "6.2" ] && [ "$compiler_version" = "rc1" ]; then \
|
||||
sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.2-20.04-1_all.deb --no-check-certificate" && \
|
||||
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.2-20.04-1_all.deb && \
|
||||
sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.2 rel-8 > /etc/apt/sources.list.d/rocm-build.list' && \
|
||||
amdgpu-repo --amdgpu-build=1794148; \
|
||||
fi
|
||||
|
||||
RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list"
|
||||
|
||||
8
Jenkinsfile
vendored
8
Jenkinsfile
vendored
@@ -493,6 +493,7 @@ def Build_CK(Map conf=[:]){
|
||||
|
||||
def variant = env.STAGE_NAME
|
||||
def retimage
|
||||
|
||||
gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
|
||||
try {
|
||||
(retimage, image) = getDockerImage(conf)
|
||||
@@ -660,9 +661,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM
|
||||
|
||||
pipeline {
|
||||
agent none
|
||||
triggers {
|
||||
parameterizedCron(CRON_SETTINGS)
|
||||
}
|
||||
options {
|
||||
parallelsAlwaysFailFast()
|
||||
}
|
||||
@@ -888,10 +886,10 @@ pipeline {
|
||||
}
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1100;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx908;gfx90a" \
|
||||
-DGPU_TARGETS="gfx1100;gfx90a" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
|
||||
@@ -7,19 +7,23 @@
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
// __gfx9__ defined in the above header via ck.hpp
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/wrapper/layout.hpp"
|
||||
#include "ck/wrapper/tensor.hpp"
|
||||
#include "ck/wrapper/operations/copy.hpp"
|
||||
#include "ck/wrapper/operations/gemm.hpp"
|
||||
#include "ck/wrapper/utils/kernel_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
@@ -204,6 +208,14 @@ void PerformGemm(const ck::index_t M,
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool is_supported = ck::is_xdl_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: xdl example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
using DataType = ck::half_t;
|
||||
const auto thread_layout =
|
||||
ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}),
|
||||
@@ -213,3 +225,4 @@ int main(int argc, char* argv[])
|
||||
3840, 4096, 4096, tile_shape, thread_layout);
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -7,18 +7,21 @@
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
// __gfx9__ defined in the above header via ck.hpp
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/wrapper/layout.hpp"
|
||||
#include "ck/wrapper/tensor.hpp"
|
||||
#include "ck/wrapper/operations/copy.hpp"
|
||||
#include "ck/wrapper/operations/gemm.hpp"
|
||||
#include "ck/wrapper/utils/kernel_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
@@ -296,6 +299,14 @@ void PerformGemm(const ck::index_t M,
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool is_supported = ck::is_xdl_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: xdl example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
using DataType = ck::half_t;
|
||||
const auto thread_layout =
|
||||
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
|
||||
@@ -305,3 +316,4 @@ int main(int argc, char* argv[])
|
||||
3840, 4096, 4096, tile_shape, thread_layout);
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -66,7 +66,7 @@ else()
|
||||
-Wunreachable-code
|
||||
-Wunused
|
||||
-Wno-reserved-identifier
|
||||
-Werror
|
||||
-Werror
|
||||
-Wno-option-ignored
|
||||
-Wsign-compare
|
||||
-Wno-extra-semi-stmt
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
cmake_minimum_required(VERSION 3.16)
|
||||
project(composable_kernel_host)
|
||||
project(composable_kernel_host LANGUAGES CXX HIP)
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
@@ -12,24 +12,38 @@ find_package(ROCM)
|
||||
include(ROCMInstallTargets)
|
||||
include(ROCMTest)
|
||||
|
||||
add_compile_options(-std=c++17)
|
||||
find_package(hip)
|
||||
## HIP
|
||||
set(CMAKE_HIP_PLATFORM amd)
|
||||
set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER})
|
||||
set(CMAKE_HIP_EXTENSIONS ON)
|
||||
message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}")
|
||||
|
||||
# add include directories
|
||||
include_directories(BEFORE
|
||||
${PROJECT_BINARY_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/library/include
|
||||
${HIP_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake)
|
||||
include(Embed)
|
||||
file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
|
||||
${CK_ROOT}/include/ck/*.hpp)
|
||||
${CK_ROOT}/include/ck/*.hpp)
|
||||
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
|
||||
message(STATUS "RELATIVE: ${CK_ROOT}/include")
|
||||
add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include)
|
||||
|
||||
add_definitions(-std=c++17)
|
||||
|
||||
file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
|
||||
# TODO: Use object library
|
||||
add_library(ck_host STATIC ${SOURCES})
|
||||
target_link_libraries(ck_host PRIVATE ck_headers)
|
||||
|
||||
set_target_properties(ck_host PROPERTIES
|
||||
LINKER_LANGUAGE CXX
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(ck_host PROPERTIES
|
||||
LINKER_LANGUAGE CXX
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
target_include_directories(ck_host PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
|
||||
@@ -5,24 +5,27 @@
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "ck/host/device_gemm_multiple_d/operation.hpp"
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
|
||||
using ck::host::Transform;
|
||||
|
||||
struct Emitters
|
||||
{
|
||||
// retrieve the hard-coded instances provided, template them, and then store them in a map
|
||||
std::unordered_map<std::string, std::function<std::vector<std::string>()>> m;
|
||||
|
||||
template <class T>
|
||||
void Register(const std::string& name)
|
||||
void Register(const std::string& name, const std::string& prologue, const std::string& epilogue)
|
||||
{
|
||||
m[name] = [] {
|
||||
auto configs = T::CreateOperations();
|
||||
m[name] = [&] {
|
||||
auto configs = T::CreateOperations(prologue, epilogue);
|
||||
|
||||
return Transform(configs, [](const auto& ops) { return ToTuple(ops); });
|
||||
};
|
||||
}
|
||||
|
||||
// takes in an operation instance and uses it to substitute the correct values into the template
|
||||
template <class T>
|
||||
static std::string ToTuple(const T& ops)
|
||||
{
|
||||
@@ -31,6 +34,7 @@ struct Emitters
|
||||
return "std::tuple<\n" + ck::host::JoinStrings(templates, ",\n") + ">";
|
||||
}
|
||||
|
||||
// Join together all the strings in the map
|
||||
std::string Emit(const std::string& name) { return ck::host::JoinStrings(m.at(name)(), "\n"); }
|
||||
|
||||
std::vector<std::string> List() const
|
||||
@@ -43,9 +47,38 @@ int main(int argc, const char* argv[])
|
||||
{
|
||||
std::string prog = argv[0];
|
||||
std::vector<std::string> args(argv + 1, argv + argc);
|
||||
|
||||
// Specify problem type and problem size
|
||||
ck::host::device_gemm_multiple_d::Problem prob;
|
||||
prob.M = 1024;
|
||||
prob.N = 1024;
|
||||
prob.K = 1024;
|
||||
|
||||
// user provided fusion
|
||||
std::string prologue = "";
|
||||
std::string epilogue = R"(
|
||||
struct Epilogue
|
||||
{
|
||||
__host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename E, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
|
||||
const ck::half_t& d) const
|
||||
{
|
||||
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
|
||||
}
|
||||
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};)";
|
||||
|
||||
// Load in operations into the Register
|
||||
Emitters e;
|
||||
e.Register<ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle>(
|
||||
"DeviceGemmMultipleD_Xdl_CShuffle");
|
||||
"DeviceGemmMultipleD_Xdl_CShuffle", prologue, epilogue);
|
||||
|
||||
if(args.empty() or std::any_of(args.begin(), args.end(), [](auto arg) {
|
||||
return arg == "-h" or arg == "--help";
|
||||
@@ -64,6 +97,7 @@ int main(int argc, const char* argv[])
|
||||
return 0;
|
||||
}
|
||||
|
||||
// print out all the instances for the operation that was chosen at the command line
|
||||
for(auto name : args)
|
||||
std::cout << e.Emit(name) << std::endl;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -14,10 +14,15 @@ namespace ck {
|
||||
namespace host {
|
||||
namespace device_gemm_multiple_d {
|
||||
|
||||
// defines all values need for an instance of fwd conv
|
||||
struct Operation_Xdl_CShuffle
|
||||
{
|
||||
static std::vector<std::vector<Operation_Xdl_CShuffle>> CreateOperations();
|
||||
static std::vector<Operation_Xdl_CShuffle> CreateOperations(const Problem& prob);
|
||||
// returns a vector of instances, only given fusion operators: will use default problem spec
|
||||
static std::vector<std::vector<Operation_Xdl_CShuffle>>
|
||||
CreateOperations(const std::string& prologue, const std::string& epilogue);
|
||||
// returns a vector of instances, given a problem spec and fusion operators
|
||||
static std::vector<Operation_Xdl_CShuffle>
|
||||
CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
|
||||
TensorDesc A{};
|
||||
TensorDesc B{};
|
||||
DataType acc = DataType::Float;
|
||||
@@ -27,13 +32,21 @@ struct Operation_Xdl_CShuffle
|
||||
std::string a_elem_op = PassThrough;
|
||||
std::string b_elem_op = PassThrough;
|
||||
std::string cde_elem_op = Bilinear;
|
||||
std::string prologue = "";
|
||||
std::string epilogue = "";
|
||||
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
|
||||
// tuning parameters
|
||||
operation::TileDesc tile_desc{};
|
||||
operation::BlockTransferDesc a_block_transfer{};
|
||||
operation::BlockTransferDesc b_block_transfer{};
|
||||
operation::CShuffleDesc cshuffle{};
|
||||
operation::CBlockTransferDesc c_block_transfer{};
|
||||
|
||||
// functions to update fusion operators if provided
|
||||
void update_prologue(const std::string& prologue);
|
||||
void update_epilogue(const std::string& epilogue);
|
||||
/**constexpr**/ bool IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_);
|
||||
// returns a templated instance
|
||||
Solution ToSolution() const;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -12,11 +12,14 @@ namespace ck {
|
||||
namespace host {
|
||||
namespace device_gemm_multiple_d {
|
||||
|
||||
// defines the problem specification for a GEMM operation
|
||||
struct Problem
|
||||
{
|
||||
std::size_t M = 0;
|
||||
std::size_t N = 0;
|
||||
std::size_t K = 0;
|
||||
// dimensions for GEMM operation
|
||||
std::size_t M = 0;
|
||||
std::size_t N = 0;
|
||||
std::size_t K = 0;
|
||||
// layouts for tensors
|
||||
bool TransA = false;
|
||||
bool TransB = false;
|
||||
bool TransE = false;
|
||||
@@ -29,9 +32,13 @@ struct Problem
|
||||
std::string BElementOp = PassThrough;
|
||||
std::string CDEElementOp = PassThrough;
|
||||
|
||||
// returns the correct device op file for the operation
|
||||
std::string GetIncludeHeader() const;
|
||||
|
||||
std::vector<Solution> GetSolutions(const std::string& arch) const;
|
||||
// returns a list of instances based on the problem spec and provided fusion operations
|
||||
std::vector<Solution> GetSolutions(const std::string& arch,
|
||||
const std::string& prologue,
|
||||
const std::string& epilogue) const;
|
||||
};
|
||||
|
||||
} // namespace device_gemm_multiple_d
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "ck/host/types.hpp"
|
||||
#include "ck/host/operation/gemm.hpp"
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace conv {
|
||||
|
||||
// defines the values needed for an instance of forward convolution and functions to return
|
||||
// (templated) instances
|
||||
struct Operation_Conv_Fwd_Xdl_Cshuffle
|
||||
{
|
||||
// returns a vector of instances given the fusion operations, uses default values for problem
|
||||
// spec
|
||||
static std::vector<Operation_Conv_Fwd_Xdl_Cshuffle>
|
||||
CreateOperations(const std::string& prologue, const std::string& epilogue);
|
||||
// returns a vector of instances, provided with a problem spec and fusion operations
|
||||
static std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> CreateOperations(
|
||||
const Problem_Conv_Fwd& prob, const std::string& prologue, const std::string& epilogue);
|
||||
std::size_t NumDim;
|
||||
TensorDesc A{};
|
||||
TensorDesc B{};
|
||||
DataType acc = DataType::Float;
|
||||
DataType cs_type = DataType::Half;
|
||||
std::vector<TensorDesc> Ds = {};
|
||||
TensorDesc E{};
|
||||
std::string a_elem_op = PassThrough;
|
||||
std::string b_elem_op = PassThrough;
|
||||
std::string cde_elem_op = PassThrough;
|
||||
std::string prologue = "";
|
||||
std::string epilogue = "";
|
||||
std::string conv_specialization =
|
||||
"ck::tensor_operation::device::ConvolutionForwardSpecialization::Default";
|
||||
std::string gemm_specialization =
|
||||
"ck::tensor_operation::device::GemmSpecialization::MNKPadding";
|
||||
// tuning parameters
|
||||
operation::TileDesc tile_desc{};
|
||||
operation::BlockTransferDesc a_block_transfer{};
|
||||
operation::BlockTransferDesc b_block_transfer{};
|
||||
operation::CShuffleDesc cshuffle{};
|
||||
operation::CBlockTransferDesc c_block_transfer{};
|
||||
|
||||
// functions to update fusion operations if they are provided
|
||||
void update_prologue(const std::string& prologue);
|
||||
void update_epilogue(const std::string& epilogue);
|
||||
// returns a templated instance
|
||||
Solution ToSolution() const;
|
||||
};
|
||||
|
||||
} // namespace conv
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,56 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include "ck/host/types.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace conv {
|
||||
|
||||
// defines the problem specification for a forward convolution operation
|
||||
struct Problem_Conv_Fwd
|
||||
{
|
||||
std::size_t NumDim = 0;
|
||||
// size of a forward convolution operation
|
||||
std::size_t G = 0;
|
||||
std::size_t N = 0;
|
||||
std::size_t C = 0;
|
||||
std::size_t Hi = 0;
|
||||
std::size_t Wi = 0;
|
||||
std::size_t Ho = 0;
|
||||
std::size_t Wo = 0;
|
||||
std::size_t K = 0;
|
||||
std::size_t Y = 0;
|
||||
std::size_t X = 0;
|
||||
Layout ALayout = Layout::NHWGC;
|
||||
Layout BLayout = Layout::GKYXC;
|
||||
Layout ELayout = Layout::NHWGK;
|
||||
std::vector<Layout> DsLayout = {};
|
||||
DataType ADataType = DataType::Half;
|
||||
DataType BDataType = DataType::Half;
|
||||
DataType EDataType = DataType::Half;
|
||||
std::vector<DataType> DsDataType = {};
|
||||
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
|
||||
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
|
||||
std::string CDEElementOp = "ck::tensor_operation::element_wise::PassThrough";
|
||||
|
||||
// returns the correct device op file for the operation
|
||||
std::string GetIncludeHeader() const;
|
||||
|
||||
// returns a list of instances based on the problem spec and provided fusion operations
|
||||
std::vector<Solution> GetSolutions(const std::string& arch,
|
||||
const std::string& prologue,
|
||||
const std::string& epilogue) const;
|
||||
};
|
||||
|
||||
} // namespace conv
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
@@ -4,7 +4,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
namespace ck {
|
||||
namespace host {
|
||||
|
||||
// holds the templated instance, substitues values into template from instancess
|
||||
struct Solution
|
||||
{
|
||||
|
||||
@@ -33,6 +34,7 @@ struct Solution
|
||||
std::unordered_map<std::string, std::string> template_values;
|
||||
};
|
||||
|
||||
// supported data types
|
||||
enum class DataType
|
||||
{
|
||||
Half,
|
||||
@@ -40,22 +42,28 @@ enum class DataType
|
||||
Int8,
|
||||
Int32
|
||||
};
|
||||
|
||||
std::string ToString(DataType dt);
|
||||
|
||||
// supported layouts: gemm and fwd conv
|
||||
enum class Layout
|
||||
{
|
||||
Row,
|
||||
Column
|
||||
Column,
|
||||
GKYXC,
|
||||
GKCYX,
|
||||
GNHWK,
|
||||
GNHWC,
|
||||
NHWGC,
|
||||
NHWGK
|
||||
};
|
||||
|
||||
std::string ToString(Layout dl);
|
||||
Layout ToLayout(bool Trans); // returns the layout for gemm
|
||||
|
||||
// supported GEMM types
|
||||
enum class GemmType
|
||||
{
|
||||
Default
|
||||
};
|
||||
|
||||
std::string ToString(GemmType gt);
|
||||
|
||||
struct TensorDesc
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <unordered_set>
|
||||
#include <numeric>
|
||||
#include <iterator>
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
@@ -12,6 +14,5 @@ namespace host {
|
||||
std::size_t integer_divide_ceil(std::size_t x, std::size_t y);
|
||||
|
||||
const std::unordered_set<std::string>& get_xdlop_archs();
|
||||
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_gemm_multiple_d/problem.hpp"
|
||||
#include "ck/host/device_gemm_multiple_d/operation.hpp"
|
||||
@@ -11,23 +11,28 @@ namespace ck {
|
||||
namespace host {
|
||||
namespace device_gemm_multiple_d {
|
||||
|
||||
// return the relevant device op file based on the operation
|
||||
std::string Problem::GetIncludeHeader() const
|
||||
{
|
||||
return "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp";
|
||||
}
|
||||
|
||||
std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
|
||||
// returns templated instances when provided with a problem specification
|
||||
std::vector<Solution> Problem::GetSolutions(const std::string& arch,
|
||||
const std::string& prologue,
|
||||
const std::string& epilogue) const
|
||||
{
|
||||
if(get_xdlop_archs().count(arch) == 0)
|
||||
return {};
|
||||
auto ops = ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle::CreateOperations(*this);
|
||||
auto ops = ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle::CreateOperations(
|
||||
*this, prologue, epilogue); // obtains vector of instances
|
||||
std::vector<Solution> result;
|
||||
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
|
||||
return op.ToSolution();
|
||||
return op.ToSolution(); // template instance with correct values
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace device_gemm_multiple_d
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
} // namespace ck
|
||||
|
||||
@@ -10,6 +10,7 @@ namespace ck {
|
||||
namespace host {
|
||||
namespace device_gemm_multiple_d {
|
||||
|
||||
// calculate appropriate Gemm Specification based on input tensor dimensions
|
||||
static std::string GetGemmSpec(const std::size_t m,
|
||||
const std::size_t n,
|
||||
const std::size_t k,
|
||||
@@ -30,9 +31,40 @@ static std::string GetGemmSpec(const std::size_t m,
|
||||
return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
|
||||
}
|
||||
|
||||
// function to update prologue/epilogue with user provided operation
|
||||
void Operation_Xdl_CShuffle::update_prologue(const std::string& prologue)
|
||||
{
|
||||
if(!prologue.empty())
|
||||
{
|
||||
this->prologue = prologue;
|
||||
this->cde_elem_op = "CDEElementOp";
|
||||
}
|
||||
else
|
||||
{
|
||||
this->prologue = "";
|
||||
}
|
||||
}
|
||||
|
||||
void Operation_Xdl_CShuffle::update_epilogue(const std::string& epilogue)
|
||||
{
|
||||
if(!epilogue.empty())
|
||||
{
|
||||
this->epilogue = epilogue;
|
||||
this->cde_elem_op = "CDEElementOp";
|
||||
}
|
||||
else
|
||||
{
|
||||
this->epilogue = "";
|
||||
}
|
||||
}
|
||||
|
||||
// accounts for all possible combinations of Row/Col major
|
||||
static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
|
||||
|
||||
std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(const Problem& prob)
|
||||
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
|
||||
// instances
|
||||
std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
const Problem& prob, const std::string& prologue, const std::string& epilogue)
|
||||
{
|
||||
std::vector<Operation_Xdl_CShuffle> result;
|
||||
|
||||
@@ -155,6 +187,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(con
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
// choose correct arrangement of tuning parameters based on the layout of each tensor
|
||||
const auto a_block_descriptions =
|
||||
prob.TransA ? a_block_descriptions_colmajor : a_block_descriptions_rowmajor;
|
||||
const auto b_block_descriptions =
|
||||
@@ -165,6 +198,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(con
|
||||
assert(tile_descriptions.size() == cshuffle_descriptions.size());
|
||||
assert(tile_descriptions.size() == c_block_descriptions.size());
|
||||
|
||||
// Put all values together into a single operation > store into the result vector
|
||||
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
|
||||
{
|
||||
Operation_Xdl_CShuffle x;
|
||||
@@ -188,12 +222,17 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(con
|
||||
x.tile_desc.m_per_block,
|
||||
x.tile_desc.n_per_block,
|
||||
x.tile_desc.k_per_block);
|
||||
x.update_prologue(prologue);
|
||||
x.update_epilogue(epilogue);
|
||||
result.push_back(x);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<std::vector<Operation_Xdl_CShuffle>> Operation_Xdl_CShuffle::CreateOperations()
|
||||
// set up instances when not provided with a problem specification, use default operation values and
|
||||
// all possible layout combinations
|
||||
std::vector<std::vector<Operation_Xdl_CShuffle>>
|
||||
Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue)
|
||||
{
|
||||
std::vector<Problem> problems;
|
||||
for(bool TransA : {true, false})
|
||||
@@ -204,7 +243,8 @@ std::vector<std::vector<Operation_Xdl_CShuffle>> Operation_Xdl_CShuffle::CreateO
|
||||
prob.TransB = TransB;
|
||||
problems.push_back(prob);
|
||||
}
|
||||
return Transform(problems, [](const Problem& p) { return CreateOperations(p); });
|
||||
return Transform(problems,
|
||||
[&](const Problem& p) { return CreateOperations(p, prologue, epilogue); });
|
||||
}
|
||||
|
||||
static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
|
||||
@@ -224,9 +264,20 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
|
||||
"${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
|
||||
"${CDEBlockTransferScalarPerVector_NPerBlock}>";
|
||||
|
||||
// use hardcoded instances from vector of operations to substitute values into instance template
|
||||
Solution Operation_Xdl_CShuffle::ToSolution() const
|
||||
{
|
||||
std::unordered_map<std::string, std::string> values = {
|
||||
{"name",
|
||||
std::to_string(this->tile_desc.block_size) + "_" +
|
||||
std::to_string(this->tile_desc.m_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.n_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.k_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" +
|
||||
std::to_string(this->tile_desc.m_per_XDL) + "_" +
|
||||
std::to_string(this->tile_desc.n_per_XDL) + "_" +
|
||||
std::to_string(this->tile_desc.m_Xdl_per_wave) + "_" +
|
||||
std::to_string(this->tile_desc.n_Xdl_per_wave)},
|
||||
{"LayoutA", ToString(this->A.layout)},
|
||||
{"LayoutB", ToString(this->B.layout)},
|
||||
{"LayoutDs",
|
||||
|
||||
42
codegen/src/device_grouped_conv_fwd_multiple_abd.cpp
Normal file
42
codegen/src/device_grouped_conv_fwd_multiple_abd.cpp
Normal file
@@ -0,0 +1,42 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
|
||||
#include "ck/host/utils.hpp"
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace conv {
|
||||
|
||||
// return the relevant device op file based on the operation
|
||||
// NOTE: this is a modified version of the original CK file that calls the kernel from a device
|
||||
// function and makes the Argument class accessible on the device
|
||||
std::string Problem_Conv_Fwd::GetIncludeHeader() const
|
||||
{
|
||||
return "ck/tensor_operation/gpu/device/impl/"
|
||||
"codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp";
|
||||
}
|
||||
|
||||
// return vector of forward convolution instances when provided with a problem instance
|
||||
std::vector<Solution> Problem_Conv_Fwd::GetSolutions(const std::string& arch,
|
||||
const std::string& prologue,
|
||||
const std::string& epilogue) const
|
||||
{
|
||||
if(get_xdlop_archs().count(arch) == 0)
|
||||
return {};
|
||||
auto ops = ck::host::conv::Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations(
|
||||
*this, prologue, epilogue);
|
||||
std::vector<Solution> result;
|
||||
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
|
||||
return op.ToSolution();
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace conv
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,364 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
|
||||
#include <iostream>
|
||||
#include "ck/host/stringutils.hpp"
|
||||
#include "ck/host/utils.hpp"
|
||||
#include <cassert>
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace conv {
|
||||
|
||||
// calculate appropriate Gemm Specification based on input tensor dimensions
|
||||
// NOTE: in CK, MNKPadding is always used for forward convolution
|
||||
static std::string GetGemmSpec(const std::size_t m,
|
||||
const std::size_t n,
|
||||
const std::size_t k,
|
||||
const std::size_t m_per_block,
|
||||
const std::size_t n_per_block,
|
||||
const std::size_t k_per_block)
|
||||
{
|
||||
std::string spec = "";
|
||||
if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
|
||||
spec += "M";
|
||||
if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
|
||||
spec += "N";
|
||||
if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
|
||||
spec += "K";
|
||||
if(spec == "")
|
||||
return "ck::tensor_operation::device::GemmSpecialization::Default";
|
||||
|
||||
return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
|
||||
}
|
||||
|
||||
// function to update prologue/epilogue with user provided operation
|
||||
void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& prologue)
|
||||
{
|
||||
if(!prologue.empty())
|
||||
{
|
||||
this->prologue = prologue;
|
||||
this->cde_elem_op = "CDEElementOp";
|
||||
}
|
||||
else
|
||||
{
|
||||
this->prologue = "";
|
||||
}
|
||||
}
|
||||
|
||||
void Operation_Conv_Fwd_Xdl_Cshuffle::update_epilogue(const std::string& epilogue)
|
||||
{
|
||||
if(!epilogue.empty())
|
||||
{
|
||||
this->epilogue = epilogue;
|
||||
this->cde_elem_op = "CDEElementOp";
|
||||
}
|
||||
else
|
||||
{
|
||||
this->epilogue = "";
|
||||
}
|
||||
}
|
||||
|
||||
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
|
||||
// instances
|
||||
std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations(
|
||||
const Problem_Conv_Fwd& prob, const std::string& prologue, const std::string& epilogue)
|
||||
{
|
||||
std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> result;
|
||||
|
||||
std::vector<operation::TileDesc> tile_descriptions = {
|
||||
// clang-format off
|
||||
// Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| NumGemmK|
|
||||
// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch|
|
||||
// | | | | | | | | Wave| Wave| Stage|
|
||||
// | | | | | | | | | | |
|
||||
{ 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, 1},
|
||||
{ 256, 128, 256, 32, 8, 8, 32, 32, 4, 2, 1},
|
||||
{ 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1},
|
||||
{ 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, 1},
|
||||
{ 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1},
|
||||
{ 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1}
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
std::vector<operation::BlockTransferDesc> a_block_descriptions = {
|
||||
// clang-format off
|
||||
// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds|
|
||||
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
|
||||
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
|
||||
// | | | | | | |
|
||||
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
|
||||
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
std::vector<operation::BlockTransferDesc> b_block_descriptions = {
|
||||
// clang-format off
|
||||
// BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds|
|
||||
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
|
||||
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
|
||||
// | | | | | | |
|
||||
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
|
||||
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
std::vector<operation::CShuffleDesc> cshuffle_descriptions = {
|
||||
// clang-format off
|
||||
// CShuffle| CShuffle|
|
||||
// MXdlPerWave| NXdlPerWave|
|
||||
// PerShuffle| PerShuffle|
|
||||
// | |
|
||||
{ 1, 1},
|
||||
{ 1, 1},
|
||||
{ 1, 1},
|
||||
{ 1, 1},
|
||||
{ 1, 1},
|
||||
{ 1, 1}
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
std::vector<operation::CBlockTransferDesc> c_block_descriptions = {
|
||||
// clang-format off
|
||||
// CBlockTransferClusterLengths| CBlockTransfer
|
||||
// _MBlock_MWaveMPerXdl| ScalarPerVector
|
||||
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
|
||||
// |
|
||||
{ S<1, 16, 1, 4>, 1},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1, 4>, 1},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1, 8>, 8}
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
assert(tile_descriptions.size() == a_block_descriptions.size());
|
||||
assert(tile_descriptions.size() == b_block_descriptions.size());
|
||||
assert(tile_descriptions.size() == cshuffle_descriptions.size());
|
||||
assert(tile_descriptions.size() == c_block_descriptions.size());
|
||||
|
||||
// Put all values together into a single operation > store into the result vector
|
||||
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
|
||||
{
|
||||
Operation_Conv_Fwd_Xdl_Cshuffle x;
|
||||
x.NumDim = prob.NumDim;
|
||||
x.tile_desc = tile_descriptions[i];
|
||||
x.a_block_transfer = a_block_descriptions[i];
|
||||
x.b_block_transfer = b_block_descriptions[i];
|
||||
x.cshuffle = cshuffle_descriptions[i];
|
||||
x.c_block_transfer = c_block_descriptions[i];
|
||||
x.A = TensorDesc{prob.ADataType, prob.ALayout};
|
||||
x.B = TensorDesc{prob.BDataType, prob.BLayout};
|
||||
x.E = TensorDesc{prob.EDataType, prob.ELayout};
|
||||
x.Ds = Transform(prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) {
|
||||
return TensorDesc{dt, lo};
|
||||
});
|
||||
x.a_elem_op = prob.AElementOp;
|
||||
x.b_elem_op = prob.BElementOp;
|
||||
x.cde_elem_op = prob.CDEElementOp;
|
||||
x.update_prologue(prologue);
|
||||
x.update_epilogue(epilogue);
|
||||
result.push_back(x);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// set up instances when not provided with a problem specification, use default operation values
|
||||
std::vector<Operation_Conv_Fwd_Xdl_Cshuffle>
|
||||
Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations(const std::string& prologue,
|
||||
const std::string& epilogue)
|
||||
{
|
||||
Problem_Conv_Fwd prob;
|
||||
return CreateOperations(prob, prologue, epilogue);
|
||||
}
|
||||
|
||||
static const char* const CopyDevice_ConvTemplate =
|
||||
R"(
|
||||
${Prologue}
|
||||
${Epilogue}
|
||||
|
||||
using CDEElementOp = Epilogue;
|
||||
using DeviceConv = ck::tensor_operation::device::CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<${NumDim}, ${LayoutA}, ${LayoutB}, ${LayoutDs}, ${LayoutE}, ${ADataType}, ${BDataType}, ${AccDataType}, ${CShuffleDataType}, ${DsDataType}, ${EDataType}, ${AElementwiseOperation}, ${BElementwiseOperation}, ${CDEElementwiseOperation}, ${ConvSpecialization}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, ${MPerBlock}, ${NPerBlock}, ${KPerBlock}, ${AK1}, ${BK1}, ${MPerXDL}, ${NPerXDL}, ${MXdlPerWave}, ${NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, ${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, ${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, ${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, ${BBlockTransferThreadClusterLengths_BK0_N_BK1}, ${BBlockTransferThreadClusterArrangeOrder}, ${BBlockTransferSrcAccessOrder}, ${BBlockTransferSrcVectorDim}, ${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, ${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, ${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, ${CDEBlockTransferScalarPerVector_NPerBlock}>;
|
||||
|
||||
constexpr ck::index_t NumATensor = ck::tensor_operation::device::GetNumABTensors<false, ${ADataType}>();
|
||||
constexpr ck::index_t NumBTensor = ck::tensor_operation::device::GetNumABTensors<false, ${BDataType}>();
|
||||
|
||||
extern "C" __global__ void run_${name}(
|
||||
const ${ADataType}* in_dev,
|
||||
const ${BDataType}* wei_dev,
|
||||
${EDataType}* __restrict__ out_dev,
|
||||
ck::Array<ck::index_t, ${NumDim} + 3> in_lengths,
|
||||
ck::Array<ck::index_t, ${NumDim} + 3> in_strides,
|
||||
ck::Array<ck::index_t, ${NumDim} + 3> wei_lengths,
|
||||
ck::Array<ck::index_t, ${NumDim} + 3> wei_strides,
|
||||
ck::Array<ck::index_t, ${NumDim} + 3> out_lengths,
|
||||
ck::Array<ck::index_t, ${NumDim} + 3> out_strides,
|
||||
ck::Array<ck::index_t, ${NumDim}> conv_filter_strides,
|
||||
ck::Array<ck::index_t, ${NumDim}> conv_filter_dilations,
|
||||
ck::Array<ck::index_t, ${NumDim}> input_left_pads,
|
||||
ck::Array<ck::index_t, ${NumDim}> input_right_pads,
|
||||
const ${AElementwiseOperation} a_element_op,
|
||||
const ${BElementwiseOperation} b_element_op,
|
||||
const ${CDEElementwiseOperation} cde_element_op
|
||||
){
|
||||
|
||||
|
||||
auto arg = DeviceConv::Argument(in_dev,
|
||||
wei_dev,
|
||||
ck::Array<const void*, 0>{},
|
||||
out_dev,
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
ck::Array<ck::Array<ck::index_t, ${NumDim} + 3>, 0>{},
|
||||
ck::Array<ck::Array<ck::index_t, ${NumDim} + 3>, 0>{},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
${AElementwiseOperation}{},
|
||||
${BElementwiseOperation}{},
|
||||
${CDEElementwiseOperation}{1.0f, 1.0f});
|
||||
|
||||
constexpr ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler();
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = DeviceConv::GridwiseGemm;
|
||||
|
||||
static constexpr auto I0 = ck::Number<0>{};
|
||||
|
||||
ck::tensor_operation::device::device_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
const ${ADataType}*,
|
||||
const ${BDataType}*,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
${EDataType},
|
||||
${AElementwiseOperation},
|
||||
${BElementwiseOperation},
|
||||
${CDEElementwiseOperation},
|
||||
DeviceConv::AGridDesc_AK0_M_AK1,
|
||||
DeviceConv::BGridDesc_BK0_N_BK1,
|
||||
DeviceConv::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceConv::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceConv::Block2ETileMap,
|
||||
ck::tensor_operation::device::ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, 0>,
|
||||
ck::integral_constant<bool, true>{},
|
||||
false,
|
||||
false>
|
||||
(
|
||||
arg.p_as_grid_.At(I0),
|
||||
arg.p_bs_grid_.At(I0),
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_g_n_c_wis_lengths_[0], // Group count
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_batch_
|
||||
);
|
||||
|
||||
}
|
||||
)";
|
||||
|
||||
// use hardcoded instances from vector of operations to substitute values into instance template
|
||||
Solution Operation_Conv_Fwd_Xdl_Cshuffle::ToSolution() const
|
||||
{
|
||||
std::unordered_map<std::string, std::string> values = {
|
||||
{"name",
|
||||
std::to_string(this->tile_desc.block_size) + "_" +
|
||||
std::to_string(this->tile_desc.m_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.n_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.k_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" +
|
||||
std::to_string(this->tile_desc.m_per_XDL) + "_" +
|
||||
std::to_string(this->tile_desc.n_per_XDL) + "_" +
|
||||
std::to_string(this->tile_desc.m_Xdl_per_wave) + "_" +
|
||||
std::to_string(this->tile_desc.n_Xdl_per_wave)},
|
||||
{"NumDim", std::to_string(this->NumDim)},
|
||||
{"LayoutA", ToString(this->A.layout)},
|
||||
{"LayoutB", ToString(this->B.layout)},
|
||||
{"LayoutDs",
|
||||
MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.layout); }))},
|
||||
{"LayoutE", ToString(this->E.layout)},
|
||||
{"ADataType", ToString(this->A.element)},
|
||||
{"BDataType", ToString(this->B.element)},
|
||||
{"AccDataType", ToString(this->acc)},
|
||||
{"ComputeDataType", ToString(this->A.element)},
|
||||
{"CShuffleDataType", ToString(this->cs_type)},
|
||||
{"DsDataType",
|
||||
MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.element); }))},
|
||||
{"EDataType", ToString(this->E.element)},
|
||||
{"AElementwiseOperation", this->a_elem_op},
|
||||
{"BElementwiseOperation", this->b_elem_op},
|
||||
{"CDEElementwiseOperation", this->cde_elem_op},
|
||||
{"Prologue", this->prologue},
|
||||
{"Epilogue", this->epilogue},
|
||||
{"ConvSpecialization", this->conv_specialization},
|
||||
{"GemmSpecialization", this->gemm_specialization},
|
||||
{"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)},
|
||||
{"BlockSize", std::to_string(this->tile_desc.block_size)},
|
||||
{"MPerBlock", std::to_string(this->tile_desc.m_per_block)},
|
||||
{"NPerBlock", std::to_string(this->tile_desc.n_per_block)},
|
||||
{"KPerBlock", std::to_string(this->tile_desc.k_per_block)},
|
||||
{"AK1", std::to_string(this->tile_desc.ak1)},
|
||||
{"BK1", std::to_string(this->tile_desc.bk1)},
|
||||
{"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)},
|
||||
{"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)},
|
||||
{"MXdlPerWave", std::to_string(this->tile_desc.m_Xdl_per_wave)},
|
||||
{"NXdlPerWave", std::to_string(this->tile_desc.n_Xdl_per_wave)},
|
||||
{"ABlockTransferThreadClusterLengths_AK0_M_AK1",
|
||||
this->a_block_transfer.thread_cluster_length},
|
||||
{"ABlockTransferThreadClusterArrangeOrder",
|
||||
this->a_block_transfer.thread_cluster_arrange_order},
|
||||
{"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order},
|
||||
{"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)},
|
||||
{"ABlockTransferSrcScalarPerVector",
|
||||
std::to_string(this->a_block_transfer.src_scalar_per_vector)},
|
||||
{"ABlockTransferDstScalarPerVector_AK1",
|
||||
std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)},
|
||||
{"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)},
|
||||
{"BBlockTransferThreadClusterLengths_BK0_N_BK1",
|
||||
this->b_block_transfer.thread_cluster_length},
|
||||
{"BBlockTransferThreadClusterArrangeOrder",
|
||||
this->b_block_transfer.thread_cluster_arrange_order},
|
||||
{"BBlockTransferSrcAccessOrder", this->b_block_transfer.src_access_order},
|
||||
{"BBlockTransferSrcVectorDim", std::to_string(this->b_block_transfer.src_vec_dim)},
|
||||
{"BBlockTransferSrcScalarPerVector",
|
||||
std::to_string(this->b_block_transfer.src_scalar_per_vector)},
|
||||
{"BBlockTransferDstScalarPerVector_BK1",
|
||||
std::to_string(this->b_block_transfer.dst_scalar_per_vector_k1)},
|
||||
{"BBlockLdsExtraN", std::to_string(this->b_block_transfer.lds_add_extra_dim)},
|
||||
{"CShuffleMXdlPerWavePerShuffle",
|
||||
std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)},
|
||||
{"CShuffleNXdlPerWavePerShuffle",
|
||||
std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)},
|
||||
{"CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock",
|
||||
this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
|
||||
{"CDEBlockTransferScalarPerVector_NPerBlock",
|
||||
std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
|
||||
};
|
||||
|
||||
return Solution{InterpolateString(CopyDevice_ConvTemplate, values), std::move(values)};
|
||||
}
|
||||
|
||||
} // namespace conv
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
@@ -14,4 +14,4 @@ std::unordered_map<std::string_view, std::string_view> GetHeaders()
|
||||
}
|
||||
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
} // namespace ck
|
||||
|
||||
@@ -29,12 +29,20 @@ std::string ToString(DataType dt)
|
||||
throw std::runtime_error("Incorrect data type");
|
||||
}
|
||||
|
||||
Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
|
||||
|
||||
std::string ToString(Layout dl)
|
||||
{
|
||||
switch(dl)
|
||||
{
|
||||
case Layout::Row: return "ck::tensor_layout::gemm::RowMajor";
|
||||
case Layout::Column: return "ck::tensor_layout::gemm::ColumnMajor";
|
||||
case Layout::GKCYX: return "ck::tensor_layout::convolution::GKCYX";
|
||||
case Layout::GKYXC: return "ck::tensor_layout::convolution::GKYXC";
|
||||
case Layout::GNHWK: return "ck::tensor_layout::convolution::GNHWK";
|
||||
case Layout::GNHWC: return "ck::tensor_layout::convolution::GNHWC";
|
||||
case Layout::NHWGC: return "ck::tensor_layout::convolution::NHWGC";
|
||||
case Layout::NHWGK: return "ck::tensor_layout::convolution::NHWGK";
|
||||
}
|
||||
throw std::runtime_error("Incorrect layout");
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/utils.hpp"
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
|
||||
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
|
||||
add_subdirectory(rtc)
|
||||
|
||||
file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp)
|
||||
foreach(TEST_SRC ${TEST_SRCS})
|
||||
get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE)
|
||||
rocm_add_test_executable(test_host_${BASE_NAME} ${TEST_SRC})
|
||||
target_link_libraries(test_host_${BASE_NAME} ck_rtc ck_host)
|
||||
target_include_directories(test_host_${BASE_NAME} PUBLIC include())
|
||||
set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP)
|
||||
get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE)
|
||||
rocm_add_test_executable(test_host_${BASE_NAME} ${TEST_SRC})
|
||||
target_link_libraries(test_host_${BASE_NAME} ck_rtc ck_host)
|
||||
# target_link_libraries(test_host_${BASE_NAME} ${CK_ROOT}/build/lib/libutility.a)
|
||||
target_include_directories(test_host_${BASE_NAME} PUBLIC include())
|
||||
target_include_directories(test_host_${BASE_NAME} PUBLIC ${CK_ROOT}/include)
|
||||
target_include_directories(test_host_${BASE_NAME} PUBLIC ${CK_ROOT}/library/include)
|
||||
endforeach()
|
||||
|
||||
134
codegen/test/common.hpp
Normal file
134
codegen/test/common.hpp
Normal file
@@ -0,0 +1,134 @@
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <test.hpp>
|
||||
#include <rtc/compile_kernel.hpp>
|
||||
#include <rtc/hip.hpp>
|
||||
#include <fstream>
|
||||
|
||||
std::vector<rtc::src_file> get_headers_for_test()
|
||||
{
|
||||
std::vector<rtc::src_file> result;
|
||||
auto hs = ck::host::GetHeaders();
|
||||
std::transform(
|
||||
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
|
||||
return {p.first, p.second};
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename V>
|
||||
std::size_t GetSize(V mLens, V mStrides)
|
||||
{
|
||||
std::size_t space = 1;
|
||||
for(std::size_t i = 0; i < mLens.Size(); ++i)
|
||||
{
|
||||
if(mLens[i] == 0)
|
||||
continue;
|
||||
|
||||
space += (mLens[i] - 1) * mStrides[i];
|
||||
}
|
||||
return space;
|
||||
}
|
||||
|
||||
template <class T, typename V>
|
||||
rtc::buffer<T> generate_buffer(V mLens, V mStrides, std::size_t seed = 0)
|
||||
{
|
||||
std::size_t space = GetSize(mLens, mStrides);
|
||||
rtc::buffer<T> result(space);
|
||||
std::mt19937 gen(seed);
|
||||
std::uniform_real_distribution<double> dis(-1.0);
|
||||
std::generate(result.begin(), result.end(), [&] { return dis(gen); });
|
||||
// std::fill(result.begin(), result.end(), 1);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01)
|
||||
{
|
||||
return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) {
|
||||
return fabs(x - y) < atol + rtol * fabs(y);
|
||||
});
|
||||
}
|
||||
|
||||
std::string classify(double x)
|
||||
{
|
||||
switch(std::fpclassify(x))
|
||||
{
|
||||
case FP_INFINITE: return "inf";
|
||||
case FP_NAN: return "nan";
|
||||
case FP_NORMAL: return "normal";
|
||||
case FP_SUBNORMAL: return "subnormal";
|
||||
case FP_ZERO: return "zero";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
template <class Buffer>
|
||||
void print_classification(const Buffer& x)
|
||||
{
|
||||
std::unordered_set<std::string> result;
|
||||
for(const auto& i : x)
|
||||
result.insert(classify(i));
|
||||
for(const auto& c : result)
|
||||
std::cout << c << ", ";
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <class Buffer>
|
||||
void print_statistics(const Buffer& x)
|
||||
{
|
||||
std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", ";
|
||||
std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", ";
|
||||
double num_elements = x.size();
|
||||
auto mean =
|
||||
std::accumulate(x.begin(), x.end(), double{0.0}, std::plus<double>{}) / num_elements;
|
||||
auto stddev = std::sqrt(
|
||||
std::accumulate(x.begin(),
|
||||
x.end(),
|
||||
double{0.0},
|
||||
[&](double r, double v) { return r + std::pow((v - mean), 2.0); }) /
|
||||
num_elements);
|
||||
std::cout << "Mean: " << mean << ", ";
|
||||
std::cout << "StdDev: " << stddev << "\n";
|
||||
}
|
||||
|
||||
template <class Buffer>
|
||||
void print_preview(const Buffer& x)
|
||||
{
|
||||
if(x.size() <= 10)
|
||||
{
|
||||
std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i << ", "; });
|
||||
}
|
||||
else
|
||||
{
|
||||
std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; });
|
||||
std::cout << "..., ";
|
||||
std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; });
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct check_all
|
||||
{
|
||||
rtc::buffer<T> data{};
|
||||
bool operator()(const rtc::buffer<T>& x)
|
||||
{
|
||||
if(data.empty())
|
||||
{
|
||||
data = x;
|
||||
return true;
|
||||
}
|
||||
return allclose(data, x);
|
||||
}
|
||||
};
|
||||
|
||||
template <class Solution>
|
||||
auto report(const Solution& solution, bool pass)
|
||||
{
|
||||
return test::make_predicate(solution.ToTemplateString(), [=] { return pass; });
|
||||
}
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <test.hpp>
|
||||
#include <rtc/compile_kernel.hpp>
|
||||
#include <rtc/hip.hpp>
|
||||
#include <fstream>
|
||||
|
||||
using half = _Float16;
|
||||
// using half = __fp16;
|
||||
@@ -159,7 +160,10 @@ TEST_CASE(test_problem_kernel)
|
||||
auto b = to_gpu(generate_buffer<half>(1024 * 1024, 1));
|
||||
auto c = to_gpu(generate_buffer<half>(1024 * 1024, 2));
|
||||
|
||||
for(auto solution : prob.GetSolutions("gfx90a"))
|
||||
std::string epilogue = "";
|
||||
std::string prologue = "";
|
||||
|
||||
for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue))
|
||||
{
|
||||
auto src = ck::host::InterpolateString(gemm_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
@@ -178,6 +182,7 @@ TEST_CASE(test_problem_kernel)
|
||||
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
|
||||
ck::host::integer_divide_ceil(prob.N, n_per_block);
|
||||
k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data());
|
||||
|
||||
CHECK(report(solution, check(rtc::from_gpu(c))));
|
||||
}
|
||||
}
|
||||
|
||||
209
codegen/test/grouped_conv_fwd_multiple_d_v1.cpp
Normal file
209
codegen/test/grouped_conv_fwd_multiple_d_v1.cpp
Normal file
@@ -0,0 +1,209 @@
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
|
||||
#include "ck/host/headers.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
#include "ck/host/utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/helper.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include <test.hpp>
|
||||
#include <rtc/compile_kernel.hpp>
|
||||
#include <rtc/hip.hpp>
|
||||
#include "common.hpp"
|
||||
#include <fstream>
|
||||
|
||||
// Need this for verification
|
||||
/**struct Epilogue
|
||||
{
|
||||
Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename E, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
|
||||
const ck::half_t& d) const
|
||||
{
|
||||
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
|
||||
}
|
||||
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};**/
|
||||
const std::string conv_compile_check = R"__ck__(
|
||||
#include <${include}>
|
||||
|
||||
${template};
|
||||
|
||||
)__ck__";
|
||||
|
||||
TEST_CASE(test_problem_kernel)
|
||||
{
|
||||
// set up problem specification
|
||||
ck::host::conv::Problem_Conv_Fwd prob;
|
||||
prob.NumDim = 2;
|
||||
prob.G = 32;
|
||||
prob.N = 256;
|
||||
prob.C = 32;
|
||||
prob.K = 64;
|
||||
prob.Y = 3;
|
||||
prob.X = 3;
|
||||
prob.Hi = 28;
|
||||
prob.Wi = 28;
|
||||
prob.Ho = 28;
|
||||
prob.Wo = 28;
|
||||
check_all<ck::half_t> check;
|
||||
|
||||
// user provided fusion operations
|
||||
std::string epilogue = R"(
|
||||
struct Epilogue
|
||||
{
|
||||
__host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename E, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
|
||||
const ck::half_t& d) const
|
||||
{
|
||||
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
|
||||
}
|
||||
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};
|
||||
)";
|
||||
std::string prologue = "";
|
||||
|
||||
// length+stride arrays
|
||||
ck::Array<ck::index_t, 5> in_lengths{static_cast<int>(prob.G),
|
||||
static_cast<int>(prob.N),
|
||||
static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Hi),
|
||||
static_cast<int>(prob.Wi)};
|
||||
ck::Array<ck::index_t, 5> out_lengths{static_cast<int>(prob.G),
|
||||
static_cast<int>(prob.N),
|
||||
static_cast<int>(prob.K),
|
||||
static_cast<int>(prob.Ho),
|
||||
static_cast<int>(prob.Wo)};
|
||||
ck::Array<ck::index_t, 5> wei_lengths{static_cast<int>(prob.G),
|
||||
static_cast<int>(prob.K),
|
||||
static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Y),
|
||||
static_cast<int>(prob.X)};
|
||||
ck::Array<ck::index_t, 5> d_lengths = {};
|
||||
|
||||
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
|
||||
1,
|
||||
static_cast<int>(prob.Wi * prob.G * prob.C),
|
||||
static_cast<int>(prob.G * prob.C)};
|
||||
ck::Array<ck::index_t, 5> out_strides{static_cast<int>(prob.K),
|
||||
static_cast<int>(prob.Ho * prob.Wo * prob.G * prob.K),
|
||||
1,
|
||||
static_cast<int>(prob.Wo * prob.G * prob.K),
|
||||
static_cast<int>(prob.G * prob.K)};
|
||||
ck::Array<ck::index_t, 5> wei_strides{static_cast<int>(prob.K * prob.Y * prob.X * prob.C),
|
||||
static_cast<int>(prob.Y * prob.X * prob.C),
|
||||
1,
|
||||
static_cast<int>(prob.X * prob.C),
|
||||
static_cast<int>(prob.C)};
|
||||
ck::Array<ck::index_t, 5> d_strides = {};
|
||||
|
||||
ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2};
|
||||
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
|
||||
ck::Array<ck::index_t, 2> input_left_pads = {1, 1};
|
||||
ck::Array<ck::index_t, 2> input_right_pads = {1, 1};
|
||||
|
||||
// move the data onto the device
|
||||
auto in_dev =
|
||||
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(in_lengths, in_strides, 0));
|
||||
auto wei_dev =
|
||||
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(wei_lengths, wei_strides, 1));
|
||||
auto out_dev =
|
||||
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(out_lengths, out_strides, 2));
|
||||
|
||||
// CK Verficiation: Reference Kernel
|
||||
/**bool pass = true;
|
||||
Tensor<ck::half_t> in_host(in_lengths, in_strides);
|
||||
in_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
|
||||
Tensor<ck::half_t> wei_host(wei_lengths, wei_strides);
|
||||
wei_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
|
||||
Tensor<ck::half_t> out_host(out_lengths, out_strides);
|
||||
|
||||
std::vector<ck::index_t> conv_filter_strides_ = {2, 2};
|
||||
std::vector<ck::index_t> conv_filter_dilations_ = {1, 1};
|
||||
std::vector<ck::index_t> input_left_pads_ = {1, 1};
|
||||
std::vector<ck::index_t> input_right_pads_ = {1, 1};
|
||||
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
|
||||
2,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Epilogue>();
|
||||
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(in_host,
|
||||
wei_host,
|
||||
out_host,
|
||||
conv_filter_strides_,
|
||||
conv_filter_dilations_,
|
||||
input_left_pads_,
|
||||
input_right_pads_,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
Epilogue{1.0f, 1.0f});
|
||||
out_host.SetZero();
|
||||
ref_invoker.Run(ref_argument);**/
|
||||
|
||||
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
|
||||
{
|
||||
// substitute instance values into the template
|
||||
auto src = ck::host::InterpolateString(
|
||||
conv_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}});
|
||||
|
||||
auto srcs = get_headers_for_test();
|
||||
srcs.push_back({"main.cpp", src});
|
||||
rtc::compile_options options;
|
||||
auto name = solution.GetTemplateParameter<std::string>("name");
|
||||
options.kernel_name = "run_" + name;
|
||||
auto k = rtc::compile_kernel(srcs, options);
|
||||
|
||||
// Grid size calculation
|
||||
auto block_size = solution.GetTemplateParameter<ck::index_t>("BlockSize");
|
||||
|
||||
auto tmp = get_launch_params(solution, out_lengths, out_strides);
|
||||
|
||||
auto grid_size = tmp * in_lengths[1];
|
||||
|
||||
// launch the kernel with arguments needed for the argument pointer
|
||||
k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(),
|
||||
wei_dev.data(),
|
||||
out_dev.data(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
// auto res = rtc::from_gpu(out_dev);
|
||||
// pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
// assert(pass);
|
||||
|
||||
// Simple check: this checks that the output from each instance matches the output from the
|
||||
// first instance
|
||||
CHECK(report(solution, check(rtc::from_gpu(out_dev))));
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, const char* argv[]) { test::run(argc, argv); }
|
||||
209
codegen/test/grouped_conv_fwd_multiple_d_v2.cpp
Normal file
209
codegen/test/grouped_conv_fwd_multiple_d_v2.cpp
Normal file
@@ -0,0 +1,209 @@
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
|
||||
#include "ck/host/headers.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
#include "ck/host/utils.hpp"
|
||||
#include "common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/helper.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include <test.hpp>
|
||||
#include <rtc/compile_kernel.hpp>
|
||||
#include <rtc/hip.hpp>
|
||||
#include <fstream>
|
||||
|
||||
// need this for validation
|
||||
/**struct Epilogue
|
||||
{
|
||||
Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename E, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
|
||||
const ck::half_t& d) const
|
||||
{
|
||||
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
|
||||
}
|
||||
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};**/
|
||||
const std::string conv_compile_check = R"__ck__(
|
||||
#include <${include}>
|
||||
|
||||
${template};
|
||||
|
||||
)__ck__";
|
||||
|
||||
TEST_CASE(test_problem_kernel)
|
||||
{
|
||||
// set up problem specification
|
||||
ck::host::conv::Problem_Conv_Fwd prob;
|
||||
prob.NumDim = 2;
|
||||
prob.G = 32;
|
||||
prob.N = 256;
|
||||
prob.C = 32;
|
||||
prob.K = 64;
|
||||
prob.Y = 3;
|
||||
prob.X = 3;
|
||||
prob.Hi = 28;
|
||||
prob.Wi = 28;
|
||||
prob.Ho = 28;
|
||||
prob.Wo = 28;
|
||||
check_all<ck::half_t> check;
|
||||
|
||||
// user provided fusion operations
|
||||
std::string epilogue = R"(
|
||||
struct Epilogue
|
||||
{
|
||||
__host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename E, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
|
||||
const ck::half_t& d) const
|
||||
{
|
||||
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
|
||||
}
|
||||
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};
|
||||
)";
|
||||
std::string prologue = "";
|
||||
|
||||
// length+stride arrays
|
||||
ck::Array<ck::index_t, 5> in_lengths{static_cast<int>(prob.G),
|
||||
static_cast<int>(prob.N),
|
||||
static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Hi),
|
||||
static_cast<int>(prob.Wi)};
|
||||
ck::Array<ck::index_t, 5> out_lengths{static_cast<int>(prob.G),
|
||||
static_cast<int>(prob.N),
|
||||
static_cast<int>(prob.K),
|
||||
static_cast<int>(prob.Ho),
|
||||
static_cast<int>(prob.Wo)};
|
||||
ck::Array<ck::index_t, 5> wei_lengths{static_cast<int>(prob.G),
|
||||
static_cast<int>(prob.K),
|
||||
static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Y),
|
||||
static_cast<int>(prob.X)};
|
||||
ck::Array<ck::index_t, 5> d_lengths = {};
|
||||
|
||||
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
|
||||
1,
|
||||
static_cast<int>(prob.Wi * prob.G * prob.C),
|
||||
static_cast<int>(prob.G * prob.C)};
|
||||
ck::Array<ck::index_t, 5> out_strides{static_cast<int>(prob.K),
|
||||
static_cast<int>(prob.Ho * prob.Wo * prob.G * prob.K),
|
||||
1,
|
||||
static_cast<int>(prob.Wo * prob.G * prob.K),
|
||||
static_cast<int>(prob.G * prob.K)};
|
||||
ck::Array<ck::index_t, 5> wei_strides{static_cast<int>(prob.K * prob.Y * prob.X * prob.C),
|
||||
static_cast<int>(prob.Y * prob.X * prob.C),
|
||||
1,
|
||||
static_cast<int>(prob.X * prob.C),
|
||||
static_cast<int>(prob.C)};
|
||||
ck::Array<ck::index_t, 5> d_strides = {};
|
||||
|
||||
ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1};
|
||||
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
|
||||
ck::Array<ck::index_t, 2> input_left_pads = {0, 0};
|
||||
ck::Array<ck::index_t, 2> input_right_pads = {0, 0};
|
||||
|
||||
// move the data onto the device
|
||||
auto in_dev =
|
||||
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(in_lengths, in_strides, 0));
|
||||
auto wei_dev =
|
||||
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(wei_lengths, wei_strides, 1));
|
||||
auto out_dev =
|
||||
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(out_lengths, out_strides, 2));
|
||||
|
||||
// CK Verficiation: Reference Kernel
|
||||
/**bool pass = true;
|
||||
Tensor<ck::half_t> in_host(in_lengths, in_strides);
|
||||
in_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
|
||||
Tensor<ck::half_t> wei_host(wei_lengths, wei_strides);
|
||||
wei_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
|
||||
Tensor<ck::half_t> out_host(out_lengths, out_strides);
|
||||
|
||||
std::vector<ck::index_t> conv_filter_strides_ = {1, 1};
|
||||
std::vector<ck::index_t> conv_filter_dilations_ = {1, 1};
|
||||
std::vector<ck::index_t> input_left_pads_ = {0, 0};
|
||||
std::vector<ck::index_t> input_right_pads_ = {0, 0};
|
||||
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
|
||||
2,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Epilogue>();
|
||||
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(in_host,
|
||||
wei_host,
|
||||
out_host,
|
||||
conv_filter_strides_,
|
||||
conv_filter_dilations_,
|
||||
input_left_pads_,
|
||||
input_right_pads_,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
Epilogue{1.0f, 1.0f});
|
||||
out_host.SetZero();
|
||||
ref_invoker.Run(ref_argument);**/
|
||||
|
||||
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
|
||||
{
|
||||
// substitute instance values into the template
|
||||
auto src = ck::host::InterpolateString(
|
||||
conv_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}});
|
||||
|
||||
auto srcs = get_headers_for_test();
|
||||
srcs.push_back({"main.cpp", src});
|
||||
rtc::compile_options options;
|
||||
auto name = solution.GetTemplateParameter<std::string>("name");
|
||||
options.kernel_name = "run_" + name;
|
||||
auto k = rtc::compile_kernel(srcs, options);
|
||||
|
||||
// Grid size calculation
|
||||
auto block_size = solution.GetTemplateParameter<ck::index_t>("BlockSize");
|
||||
|
||||
auto tmp = get_launch_params(solution, out_lengths, out_strides);
|
||||
|
||||
auto grid_size = tmp * in_lengths[1];
|
||||
|
||||
// launch the kernel with arguments needed for the argument pointer
|
||||
k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(),
|
||||
wei_dev.data(),
|
||||
out_dev.data(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
// auto res = rtc::from_gpu(out_dev);
|
||||
// pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
// assert(pass);
|
||||
|
||||
// Simple check: this checks that the output from each instance matches the output from the
|
||||
// first instance
|
||||
CHECK(report(solution, check(rtc::from_gpu(out_dev))));
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, const char* argv[]) { test::run(argc, argv); }
|
||||
209
codegen/test/grouped_conv_fwd_multiple_d_v3.cpp
Normal file
209
codegen/test/grouped_conv_fwd_multiple_d_v3.cpp
Normal file
@@ -0,0 +1,209 @@
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
|
||||
#include "ck/host/headers.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
#include "ck/host/utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/helper.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include "common.hpp"
|
||||
#include <test.hpp>
|
||||
#include <rtc/compile_kernel.hpp>
|
||||
#include <rtc/hip.hpp>
|
||||
#include <fstream>
|
||||
|
||||
// need this for verification
|
||||
/**struct Epilogue
|
||||
{
|
||||
Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename E, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
|
||||
const ck::half_t& d) const
|
||||
{
|
||||
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
|
||||
}
|
||||
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};**/
|
||||
const std::string conv_compile_check = R"__ck__(
|
||||
#include <${include}>
|
||||
|
||||
${template};
|
||||
|
||||
)__ck__";
|
||||
|
||||
TEST_CASE(test_problem_kernel)
|
||||
{
|
||||
// set up problem specification
|
||||
ck::host::conv::Problem_Conv_Fwd prob;
|
||||
prob.NumDim = 2;
|
||||
prob.G = 32;
|
||||
prob.N = 256;
|
||||
prob.C = 32;
|
||||
prob.K = 64;
|
||||
prob.Y = 3;
|
||||
prob.X = 3;
|
||||
prob.Hi = 28;
|
||||
prob.Wi = 28;
|
||||
prob.Ho = 28;
|
||||
prob.Wo = 28;
|
||||
check_all<ck::half_t> check;
|
||||
|
||||
// user provided fusion operations
|
||||
std::string epilogue = R"(
|
||||
struct Epilogue
|
||||
{
|
||||
__host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename E, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
|
||||
const ck::half_t& d) const
|
||||
{
|
||||
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
|
||||
}
|
||||
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};
|
||||
)";
|
||||
std::string prologue = "";
|
||||
|
||||
// length+stride arrays
|
||||
ck::Array<ck::index_t, 5> in_lengths{static_cast<int>(prob.G),
|
||||
static_cast<int>(prob.N),
|
||||
static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Hi),
|
||||
static_cast<int>(prob.Wi)};
|
||||
ck::Array<ck::index_t, 5> out_lengths{static_cast<int>(prob.G),
|
||||
static_cast<int>(prob.N),
|
||||
static_cast<int>(prob.K),
|
||||
static_cast<int>(prob.Ho),
|
||||
static_cast<int>(prob.Wo)};
|
||||
ck::Array<ck::index_t, 5> wei_lengths{static_cast<int>(prob.G),
|
||||
static_cast<int>(prob.K),
|
||||
static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Y),
|
||||
static_cast<int>(prob.X)};
|
||||
ck::Array<ck::index_t, 5> d_lengths = {};
|
||||
|
||||
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
|
||||
1,
|
||||
static_cast<int>(prob.Wi * prob.G * prob.C),
|
||||
static_cast<int>(prob.G * prob.C)};
|
||||
ck::Array<ck::index_t, 5> out_strides{static_cast<int>(prob.K),
|
||||
static_cast<int>(prob.Ho * prob.Wo * prob.G * prob.K),
|
||||
1,
|
||||
static_cast<int>(prob.Wo * prob.G * prob.K),
|
||||
static_cast<int>(prob.G * prob.K)};
|
||||
ck::Array<ck::index_t, 5> wei_strides{static_cast<int>(prob.K * prob.Y * prob.X * prob.C),
|
||||
static_cast<int>(prob.Y * prob.X * prob.C),
|
||||
1,
|
||||
static_cast<int>(prob.X * prob.C),
|
||||
static_cast<int>(prob.C)};
|
||||
ck::Array<ck::index_t, 5> d_strides = {};
|
||||
|
||||
ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2};
|
||||
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
|
||||
ck::Array<ck::index_t, 2> input_left_pads = {0, 0};
|
||||
ck::Array<ck::index_t, 2> input_right_pads = {0, 0};
|
||||
|
||||
// move the data onto the device
|
||||
auto in_dev =
|
||||
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(in_lengths, in_strides, 0));
|
||||
auto wei_dev =
|
||||
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(wei_lengths, wei_strides, 1));
|
||||
auto out_dev =
|
||||
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(out_lengths, out_strides, 2));
|
||||
|
||||
// CK Verficiation: Reference Kernel
|
||||
/**bool pass = true;
|
||||
Tensor<ck::half_t> in_host(in_lengths, in_strides);
|
||||
in_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
|
||||
Tensor<ck::half_t> wei_host(wei_lengths, wei_strides);
|
||||
wei_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
|
||||
Tensor<ck::half_t> out_host(out_lengths, out_strides);
|
||||
|
||||
std::vector<ck::index_t> conv_filter_strides_ = {2, 2};
|
||||
std::vector<ck::index_t> conv_filter_dilations_ = {1, 1};
|
||||
std::vector<ck::index_t> input_left_pads_ = {0, 0};
|
||||
std::vector<ck::index_t> input_right_pads_ = {0, 0};
|
||||
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
|
||||
2,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Epilogue>();
|
||||
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(in_host,
|
||||
wei_host,
|
||||
out_host,
|
||||
conv_filter_strides_,
|
||||
conv_filter_dilations_,
|
||||
input_left_pads_,
|
||||
input_right_pads_,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
Epilogue{1.0f, 1.0f});
|
||||
out_host.SetZero();
|
||||
ref_invoker.Run(ref_argument);**/
|
||||
|
||||
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
|
||||
{
|
||||
// substitute instance values into the template
|
||||
auto src = ck::host::InterpolateString(
|
||||
conv_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}});
|
||||
|
||||
auto srcs = get_headers_for_test();
|
||||
srcs.push_back({"main.cpp", src});
|
||||
rtc::compile_options options;
|
||||
auto name = solution.GetTemplateParameter<std::string>("name");
|
||||
options.kernel_name = "run_" + name;
|
||||
auto k = rtc::compile_kernel(srcs, options);
|
||||
|
||||
// Grid size calculation
|
||||
auto block_size = solution.GetTemplateParameter<ck::index_t>("BlockSize");
|
||||
|
||||
auto tmp = get_launch_params(solution, out_lengths, out_strides);
|
||||
|
||||
auto grid_size = tmp * in_lengths[1];
|
||||
|
||||
// launch the kernel with arguments needed for the argument pointer
|
||||
k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(),
|
||||
wei_dev.data(),
|
||||
out_dev.data(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
// auto res = rtc::from_gpu(out_dev);
|
||||
// pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
// assert(pass);
|
||||
|
||||
// Simple check: this checks that the output from each instance matches the output from the
|
||||
// first instance
|
||||
CHECK(report(solution, check(rtc::from_gpu(out_dev))));
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, const char* argv[]) { test::run(argc, argv); }
|
||||
209
codegen/test/grouped_conv_fwd_multiple_d_v4.cpp
Normal file
209
codegen/test/grouped_conv_fwd_multiple_d_v4.cpp
Normal file
@@ -0,0 +1,209 @@
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"
|
||||
#include "ck/host/headers.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
#include "ck/host/utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/helper.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include "common.hpp"
|
||||
#include <test.hpp>
|
||||
#include <rtc/compile_kernel.hpp>
|
||||
#include <rtc/hip.hpp>
|
||||
#include <fstream>
|
||||
|
||||
// need this for verification
|
||||
/**struct Epilogue
|
||||
{
|
||||
Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename E, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
|
||||
const ck::half_t& d) const
|
||||
{
|
||||
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
|
||||
}
|
||||
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};**/
|
||||
const std::string conv_compile_check = R"__ck__(
|
||||
#include <${include}>
|
||||
|
||||
${template};
|
||||
|
||||
)__ck__";
|
||||
|
||||
TEST_CASE(test_problem_kernel)
|
||||
{
|
||||
// set up problem specification
|
||||
ck::host::conv::Problem_Conv_Fwd prob;
|
||||
prob.NumDim = 2;
|
||||
prob.G = 32;
|
||||
prob.N = 256;
|
||||
prob.C = 32;
|
||||
prob.K = 64;
|
||||
prob.Y = 3;
|
||||
prob.X = 3;
|
||||
prob.Hi = 28;
|
||||
prob.Wi = 28;
|
||||
prob.Ho = 28;
|
||||
prob.Wo = 28;
|
||||
check_all<ck::half_t> check;
|
||||
|
||||
// user provided fusion operations
|
||||
std::string epilogue = R"(
|
||||
struct Epilogue
|
||||
{
|
||||
__host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename E, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const D& d) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::half_t, ck::half_t>(ck::half_t& e,
|
||||
const ck::half_t& d) const
|
||||
{
|
||||
e = ck::type_convert<ck::half_t>(alpha_ * e + beta_ * ck::type_convert<float>(d));
|
||||
}
|
||||
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};
|
||||
)";
|
||||
std::string prologue = "";
|
||||
|
||||
// length+stride arrays
|
||||
ck::Array<ck::index_t, 5> in_lengths{static_cast<int>(prob.G),
|
||||
static_cast<int>(prob.N),
|
||||
static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Hi),
|
||||
static_cast<int>(prob.Wi)};
|
||||
ck::Array<ck::index_t, 5> out_lengths{static_cast<int>(prob.G),
|
||||
static_cast<int>(prob.N),
|
||||
static_cast<int>(prob.K),
|
||||
static_cast<int>(prob.Ho),
|
||||
static_cast<int>(prob.Wo)};
|
||||
ck::Array<ck::index_t, 5> wei_lengths{static_cast<int>(prob.G),
|
||||
static_cast<int>(prob.K),
|
||||
static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Y),
|
||||
static_cast<int>(prob.X)};
|
||||
ck::Array<ck::index_t, 5> d_lengths = {};
|
||||
|
||||
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
|
||||
1,
|
||||
static_cast<int>(prob.Wi * prob.G * prob.C),
|
||||
static_cast<int>(prob.G * prob.C)};
|
||||
ck::Array<ck::index_t, 5> out_strides{static_cast<int>(prob.K),
|
||||
static_cast<int>(prob.Ho * prob.Wo * prob.G * prob.K),
|
||||
1,
|
||||
static_cast<int>(prob.Wo * prob.G * prob.K),
|
||||
static_cast<int>(prob.G * prob.K)};
|
||||
ck::Array<ck::index_t, 5> wei_strides{static_cast<int>(prob.K * prob.Y * prob.X * prob.C),
|
||||
static_cast<int>(prob.Y * prob.X * prob.C),
|
||||
1,
|
||||
static_cast<int>(prob.X * prob.C),
|
||||
static_cast<int>(prob.C)};
|
||||
ck::Array<ck::index_t, 5> d_strides = {};
|
||||
|
||||
ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1};
|
||||
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
|
||||
ck::Array<ck::index_t, 2> input_left_pads = {1, 1};
|
||||
ck::Array<ck::index_t, 2> input_right_pads = {1, 1};
|
||||
|
||||
// move the data onto the device
|
||||
auto in_dev =
|
||||
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(in_lengths, in_strides, 0));
|
||||
auto wei_dev =
|
||||
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(wei_lengths, wei_strides, 1));
|
||||
auto out_dev =
|
||||
to_gpu(generate_buffer<ck::half_t, ck::Array<ck::index_t, 5>>(out_lengths, out_strides, 2));
|
||||
|
||||
// CK Verficiation: Reference Kernel
|
||||
/**bool pass = true;
|
||||
Tensor<ck::half_t> in_host(in_lengths, in_strides);
|
||||
in_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
|
||||
Tensor<ck::half_t> wei_host(wei_lengths, wei_strides);
|
||||
wei_host.GenerateTensorValue(GeneratorTensor_1<ck::half_t>{1});
|
||||
Tensor<ck::half_t> out_host(out_lengths, out_strides);
|
||||
|
||||
std::vector<ck::index_t> conv_filter_strides_ = {1, 1};
|
||||
std::vector<ck::index_t> conv_filter_dilations_ = {1, 1};
|
||||
std::vector<ck::index_t> input_left_pads_ = {1, 1};
|
||||
std::vector<ck::index_t> input_right_pads_ = {1, 1};
|
||||
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
|
||||
2,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Epilogue>();
|
||||
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(in_host,
|
||||
wei_host,
|
||||
out_host,
|
||||
conv_filter_strides_,
|
||||
conv_filter_dilations_,
|
||||
input_left_pads_,
|
||||
input_right_pads_,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
Epilogue{1.0f, 1.0f});
|
||||
out_host.SetZero();
|
||||
ref_invoker.Run(ref_argument);**/
|
||||
|
||||
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
|
||||
{
|
||||
// substitute instance values into the template
|
||||
auto src = ck::host::InterpolateString(
|
||||
conv_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}});
|
||||
|
||||
auto srcs = get_headers_for_test();
|
||||
srcs.push_back({"main.cpp", src});
|
||||
rtc::compile_options options;
|
||||
auto name = solution.GetTemplateParameter<std::string>("name");
|
||||
options.kernel_name = "run_" + name;
|
||||
auto k = rtc::compile_kernel(srcs, options);
|
||||
|
||||
// Grid size calculation
|
||||
auto block_size = solution.GetTemplateParameter<ck::index_t>("BlockSize");
|
||||
|
||||
auto tmp = get_launch_params(solution, out_lengths, out_strides);
|
||||
|
||||
auto grid_size = tmp * in_lengths[1];
|
||||
|
||||
// launch the kernel with arguments needed for the argument pointer
|
||||
k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(),
|
||||
wei_dev.data(),
|
||||
out_dev.data(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
// auto res = rtc::from_gpu(out_dev);
|
||||
// pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
// assert(pass);
|
||||
|
||||
// Simple check: this checks that the output from each instance matches the output from the
|
||||
// first instance
|
||||
CHECK(report(solution, check(rtc::from_gpu(out_dev))));
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, const char* argv[]) { test::run(argc, argv); }
|
||||
@@ -56,6 +56,8 @@ void write_string(const std::string& filename, const std::string_view& buffer)
|
||||
}
|
||||
|
||||
std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device-only"; }
|
||||
// TODO: undo after extracting the codeobj
|
||||
// std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip"; }
|
||||
|
||||
kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options)
|
||||
{
|
||||
@@ -89,6 +91,12 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
|
||||
|
||||
auto obj = read_buffer(out_path.string());
|
||||
|
||||
std::ofstream ofh("obj.o", std::ios::binary);
|
||||
for(auto i : obj)
|
||||
ofh << i;
|
||||
ofh.close();
|
||||
// int s = std::system(("/usr/bin/cp " + out_path.string() + " codeobj.bin").c_str());
|
||||
// assert(s == 0);
|
||||
return kernel{obj.data(), options.kernel_name};
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include <rtc/manage_ptr.hpp>
|
||||
#include <stdexcept>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
|
||||
namespace rtc {
|
||||
|
||||
@@ -49,7 +50,10 @@ std::size_t get_available_gpu_memory()
|
||||
size_t total;
|
||||
auto status = hipMemGetInfo(&free, &total);
|
||||
if(status != hipSuccess)
|
||||
throw std::runtime_error("Failed getting available memory: " + hip_error(status));
|
||||
{
|
||||
std::cerr << "Failed getting available memory: " + hip_error(status) << std::endl;
|
||||
return (8ull * 1024ull * 1024ull * 1024ull);
|
||||
}
|
||||
return free;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
rocm-docs-core==1.4.0
|
||||
rocm-docs-core==1.4.1
|
||||
sphinxcontrib-bibtex==2.6.2
|
||||
|
||||
@@ -103,7 +103,7 @@ requests==2.31.0
|
||||
# via
|
||||
# pygithub
|
||||
# sphinx
|
||||
rocm-docs-core==1.4.0
|
||||
rocm-docs-core==1.4.1
|
||||
# via -r requirements.in
|
||||
six==1.16.0
|
||||
# via
|
||||
|
||||
@@ -22,6 +22,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
|
||||
add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2)
|
||||
|
||||
add_example_executable(example_gemm_xdl_fp16_streamk_v3 gemm_xdl_fp16_streamk_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_streamk_v3)
|
||||
add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3)
|
||||
add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
|
||||
|
||||
@@ -7,3 +7,21 @@
|
||||
#arg3: run kernel # of times (>1)
|
||||
./bin/example_gemm_xdl 0 1 5
|
||||
```
|
||||
|
||||
# Instructions for ```example_gemm_xdl_fp16_streamk_v3```
|
||||
|
||||
## Run ```example_gemm_xdl_fp16_streamk_v3```
|
||||
```bash
|
||||
arg1: verification (0=no, 1=yes)
|
||||
arg2: initialization (0=no init, 1=integer value, 2=decimal value)
|
||||
arg3: time kernel (0=no, 1=yes)
|
||||
arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
|
||||
arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)
|
||||
arg11: Grid_size(-1 for max occupancy)
|
||||
bin/example_gemm_xdl_fp16_streamk_v3 1 2 1 3840 4096 4096 4096 4096 4096 1 -1
|
||||
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
|
||||
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
problem {M:3840, N:4096, K:4096, SA:4096, SB:4096, SC:4096, MP:4032, NP:4096, KRead:4096, KP:4096, AK0:512, BK0:2048, MBlock: 18, NBlock: 16, Stream-K Selection:1, Grid size:-1}
|
||||
Perf: 0.292022 ms, 441.23 TFlops, 330.348 GB/s, DeviceGemmXdlUniversal<MNPadding, RRR> BlkSize: 256, BlkTile: 224x256x64, WaveTile: 16x16, WaveMap: 7x8, VmemReadVec: 8x8, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, BlkGemmPipelinePrefetchStages: 2
|
||||
```
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -45,6 +45,19 @@ struct ProblemSizeStreamK final
|
||||
|
||||
ck::index_t NumSKBlocks = -1;
|
||||
};
|
||||
struct ProblemSizeStreamK_universal final
|
||||
{
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
ck::index_t Grid_size = -1; // defaults to max occupancy
|
||||
ck::index_t Streamk_sel = 1; // defaults to 1-tile SK
|
||||
};
|
||||
|
||||
struct ProblemSizeSplitK final
|
||||
{
|
||||
@@ -123,6 +136,57 @@ bool parse_cmd_args<ProblemSize>(int argc,
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
|
||||
char* argv[],
|
||||
ProblemSizeStreamK_universal& problem_size,
|
||||
ExecutionConfig& config)
|
||||
{
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc >= 10)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
|
||||
problem_size.StrideA = std::stoi(argv[7]);
|
||||
problem_size.StrideB = std::stoi(argv[8]);
|
||||
problem_size.StrideC = std::stoi(argv[9]);
|
||||
|
||||
if(argc >= 11)
|
||||
{
|
||||
problem_size.Streamk_sel = std::stoi(argv[10]);
|
||||
problem_size.Grid_size = std::stoi(argv[11]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr
|
||||
<< "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
|
||||
<< "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool parse_cmd_args<ProblemSizeStreamK>(int argc,
|
||||
char* argv[],
|
||||
@@ -165,7 +229,8 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg10: NumSKBlocks(optional)" << std::endl;
|
||||
<< "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
|
||||
<< "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -23,45 +23,45 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
|
||||
< ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
< ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
GemmDefault,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
GemmDefault,
|
||||
1, // Prefetch stage
|
||||
128, // BlockSize
|
||||
64, // MPerBlock
|
||||
128, // NPerBlock
|
||||
64, // KPerBlock
|
||||
8, // K1
|
||||
2, // K1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
|
||||
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
true,
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
true,
|
||||
1, // C shuffle (M Repeat) Per store
|
||||
1, // C shuffle (N Repeat) Per store
|
||||
S<1, 32, 1, 4>,
|
||||
S<1, 32, 1, 4>,
|
||||
8>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
48
example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp
Normal file
48
example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp"
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2_Streamk_Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
256,
|
||||
224, 256,
|
||||
64, 8, 2,
|
||||
16, 16,
|
||||
7, 8,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 8, 2, 0,
|
||||
1, 2, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
#include "run_gemm_example_streamk_v2.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); }
|
||||
@@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
break;
|
||||
case 4:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
|
||||
break;
|
||||
case 5:
|
||||
|
||||
298
example/01_gemm/run_gemm_example_streamk_v2.inc
Normal file
298
example/01_gemm/run_gemm_example_streamk_v2.inc
Normal file
@@ -0,0 +1,298 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename DataType>
|
||||
inline __host__ __device__ constexpr double get_rtol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 1e-1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 1.5e-1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
inline __host__ __device__ constexpr double get_atol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 16.1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 8192.1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
|
||||
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
|
||||
#endif
|
||||
|
||||
using namespace ck::literals;
|
||||
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
auto K = problem_size.K;
|
||||
auto StrideA = problem_size.StrideA;
|
||||
auto StrideB = problem_size.StrideB;
|
||||
auto StrideC = problem_size.StrideC;
|
||||
auto Grid_size = problem_size.Grid_size;
|
||||
auto Streamk_sel = problem_size.Streamk_sel;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
auto f_get_default_streamk_policy = [](ck::index_t streamk_sel) {
|
||||
if(streamk_sel == -1)
|
||||
{
|
||||
return static_cast<std::size_t>(4);
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(streamk_sel);
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
Streamk_sel = f_get_default_streamk_policy(Streamk_sel);
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
break;
|
||||
case 3:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
|
||||
c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
|
||||
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
|
||||
#else
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
#endif
|
||||
DeviceMem workspace;
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmV2_Streamk_Instance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
float ave_time = 0;
|
||||
|
||||
auto argument = gemm.MakeArgument(
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#else
|
||||
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#endif
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
Streamk_sel,
|
||||
Grid_size,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
|
||||
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
|
||||
|
||||
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
|
||||
|
||||
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
|
||||
#else
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
#endif
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t flop = 2_uz * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_gemm_universal_streamk_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSizeStreamK_universal problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
|
||||
}
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
struct AlphaBetaAdd
|
||||
{
|
||||
@@ -175,6 +176,14 @@ int main(int argc, char* argv[])
|
||||
exit(0);
|
||||
}
|
||||
|
||||
bool is_supported = ck::is_gfx11_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
struct AlphaBetaAdd
|
||||
{
|
||||
@@ -175,6 +176,14 @@ int main(int argc, char* argv[])
|
||||
exit(0);
|
||||
}
|
||||
|
||||
bool is_supported = ck::is_gfx11_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
@@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN =
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
true,
|
||||
false,
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
true,
|
||||
false,
|
||||
1,
|
||||
1,
|
||||
S<1, 64, 1, 2>,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common_wmma.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
// kernel data types
|
||||
using InKernelDataType = FP16;
|
||||
@@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
|
||||
|
||||
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool is_supported = ck::is_gfx11_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common_wmma.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
// kernel data types
|
||||
using InKernelDataType = I8;
|
||||
@@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
|
||||
|
||||
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool is_supported = ck::is_gfx11_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv);
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -163,4 +164,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
|
||||
|
||||
#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool is_supported = ck::is_gfx11_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
return run(argc, argv);
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -285,4 +286,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
|
||||
|
||||
#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool is_supported = ck::is_gfx11_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
return run(argc, argv);
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -71,7 +72,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
|
||||
#define CK_MHA_USE_WAVE_1
|
||||
#define CK_MHA_USE_WAVE_2
|
||||
#define CK_MHA_USE_WAVE_4
|
||||
#define CK_MHA_USE_WAVE_8
|
||||
//#define CK_MHA_USE_WAVE_8
|
||||
using DeviceMHAFactory =
|
||||
std::tuple<
|
||||
#ifdef CK_MHA_USE_WAVE_1
|
||||
@@ -277,10 +278,10 @@ using DeviceMHAFactory =
|
||||
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 64, 1, 2>, 8,
|
||||
MaskingSpec>,
|
||||
MaskingSpec>
|
||||
#endif
|
||||
#ifdef CK_MHA_USE_WAVE_8
|
||||
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
|
||||
,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
|
||||
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
|
||||
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
@@ -351,4 +352,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
|
||||
|
||||
#include "run_cross_attention_wmma.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool is_supported = ck::is_gfx11_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
return run(argc, argv);
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ Example is GQA-4
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -299,4 +300,14 @@ using ReferenceGemm1Instance =
|
||||
|
||||
#include "run_grouped_query_attention_forward_wmma.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool is_supported = ck::is_gfx11_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
return run(argc, argv);
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ Shazeer, Noam. “Fast Transformer Decoding: One Write-Head Is All You Need.”
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -284,4 +285,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm_
|
||||
|
||||
#include "run_multi_query_attention_forward_wmma.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool is_supported = ck::is_gfx11_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
return run(argc, argv);
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -71,7 +72,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
|
||||
#define CK_MHA_USE_WAVE_1
|
||||
#define CK_MHA_USE_WAVE_2
|
||||
#define CK_MHA_USE_WAVE_4
|
||||
#define CK_MHA_USE_WAVE_8
|
||||
//#define CK_MHA_USE_WAVE_8
|
||||
using DeviceMHAFactory =
|
||||
std::tuple<
|
||||
#ifdef CK_MHA_USE_WAVE_1
|
||||
@@ -277,10 +278,10 @@ using DeviceMHAFactory =
|
||||
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 64, 1, 2>, 8,
|
||||
MaskingSpec>,
|
||||
MaskingSpec>
|
||||
#endif
|
||||
#ifdef CK_MHA_USE_WAVE_8
|
||||
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
|
||||
,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
|
||||
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
|
||||
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
@@ -329,4 +330,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
|
||||
|
||||
#include "run_self_attention_wmma.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool is_supported = ck::is_gfx11_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
return run(argc, argv);
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "common.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
using OutDataType = FP16;
|
||||
using WeiDataType = FP16;
|
||||
@@ -31,4 +32,14 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat
|
||||
|
||||
#include "run_grouped_conv_bwd_data_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool is_supported = ck::is_gfx11_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
return run_grouped_conv_bwd_data_example(argc, argv);
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
endforeach()
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
||||
if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
|
||||
message("removing wmma example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
@@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
endforeach()
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
||||
if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
|
||||
message("removing wmma example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
@@ -181,7 +181,7 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME)
|
||||
# add all example subdir
|
||||
file(GLOB dir_list LIST_DIRECTORIES true *)
|
||||
FOREACH(subdir ${dir_list})
|
||||
IF(IS_DIRECTORY "${subdir}")
|
||||
if(IS_DIRECTORY "${subdir}" AND EXISTS "${subdir}/CMakeLists.txt")
|
||||
add_subdirectory(${subdir})
|
||||
ENDIF()
|
||||
ENDFOREACH()
|
||||
|
||||
@@ -1,7 +1,27 @@
|
||||
# generate a list of kernels, but not actually emit files at config stage
|
||||
# validate user-specified fmha_fwd API list
|
||||
set(EXAMPLE_FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv")
|
||||
set(EXAMPLE_FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
"semicolon-separated list of APIs to generate (${EXAMPLE_FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
if(EXAMPLE_FMHA_FWD_ENABLE_APIS STREQUAL "all")
|
||||
set(EXAMPLE_FMHA_FWD_ENABLE_APIS ${EXAMPLE_FMHA_FWD_KNOWN_APIS})
|
||||
endif()
|
||||
|
||||
foreach(api ${EXAMPLE_FMHA_FWD_ENABLE_APIS})
|
||||
if(NOT "${api}" IN_LIST EXAMPLE_FMHA_FWD_KNOWN_APIS)
|
||||
message(FATAL_ERROR "${api} isn't a known api: ${EXAMPLE_FMHA_FWD_KNOWN_APIS}.")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
|
||||
if(NOT "fwd" IN_LIST EXAMPLE_FMHA_FWD_ENABLE_APIS)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_ENABLE_APIS "fwd")
|
||||
endif()
|
||||
|
||||
string(REPLACE ";" "," EXAMPLE_FMHA_FWD_APIS "${EXAMPLE_FMHA_FWD_ENABLE_APIS}")
|
||||
# generate a list of kernels, but not actually emit files at config sta
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api fwd,fwd_appendkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
--api ${EXAMPLE_FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
)
|
||||
|
||||
execute_process(
|
||||
@@ -17,7 +37,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api fwd,fwd_appendkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--api ${EXAMPLE_FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
@@ -61,6 +81,13 @@ else()
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
endif()
|
||||
|
||||
# conditionally enable call to the fwd_splitkv API in fmha_fwd example
|
||||
if ("fwd_splitkv" IN_LIST EXAMPLE_FMHA_FWD_ENABLE_APIS)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
|
||||
else()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
|
||||
endif()
|
||||
|
||||
# Allow comparing floating points directly in order to check sentinel values
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
|
||||
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
|
||||
|
||||
671
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
Normal file
671
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
Normal file
@@ -0,0 +1,671 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
|
||||
from codegen.ops.fmha_fwd import (
|
||||
FmhaFwdTileSize,
|
||||
FmhaFwdApiTrait,
|
||||
FMHA_FWD_KERNEL_HEADER,
|
||||
FMHA_FWD_API_PER_DTYPE,
|
||||
FMHA_FWD_API_PER_HDIM_CASE,
|
||||
)
|
||||
|
||||
|
||||
FMHA_FWD_SPLITKV_PIPELINE_MAP = {
|
||||
"qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
|
||||
"qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync",
|
||||
}
|
||||
|
||||
FMHA_FWD_SPLITKV_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
namespace {{
|
||||
template <bool kHasUnevenSplits>
|
||||
struct kernel_runner {{
|
||||
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>;
|
||||
using fmha_block_warps = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
|
||||
using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
|
||||
|
||||
using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
|
||||
fmha_block_warps,
|
||||
fmha_warp_tile,
|
||||
fmha_block_warps,
|
||||
fmha_warp_tile,
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_bias},
|
||||
false,
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
kHasUnevenSplits,
|
||||
{F_occupancy}>;
|
||||
|
||||
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
fmha_shape,
|
||||
{F_mode},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_trait>;
|
||||
|
||||
using fmha_pipeline = {F_pipeline}<
|
||||
fmha_pipeline_problem>;
|
||||
|
||||
using fmha_epilogue =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel =
|
||||
ck_tile::FmhaFwdSplitKVKernel<ck_tile::FmhaFwdSplitKVTilePartitioner<fmha_shape>,
|
||||
fmha_pipeline,
|
||||
fmha_epilogue>;
|
||||
|
||||
static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel;
|
||||
auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
}};
|
||||
}}
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
|
||||
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{{
|
||||
if constexpr({F_mode} == false) {{ // batch mode
|
||||
if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
|
||||
kernel_runner<false>::run(s, a);
|
||||
}} else {{
|
||||
kernel_runner<true>::run(s, a);
|
||||
}}
|
||||
}} else {{
|
||||
kernel_runner<true>::run(s, a);
|
||||
}}
|
||||
}}
|
||||
|
||||
template<>
|
||||
std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = kernel_runner<true>::fmha_kernel; /// FIXME: choose real kernel type
|
||||
return k_::GetName();
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
namespace {{
|
||||
template <ck_tile::index_t kLogMaxSplits>
|
||||
struct kernel_runner {{
|
||||
using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad},
|
||||
{F_dvpad},
|
||||
{F_lse},
|
||||
{F_squant},
|
||||
kLogMaxSplits,
|
||||
{F_occupancy}>;
|
||||
|
||||
using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
{F_hdim},
|
||||
{F_bm0},
|
||||
{F_bn1},
|
||||
{F_mode},
|
||||
fmha_trait>;
|
||||
|
||||
using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline<
|
||||
fmha_pipeline_problem>;
|
||||
|
||||
using fmha_epilogue =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel =
|
||||
ck_tile::FmhaFwdSplitKVCombineKernel<ck_tile::FmhaFwdSplitKVCombineTilePartitioner<{F_bm0}, {F_bn1}>,
|
||||
fmha_pipeline,
|
||||
fmha_epilogue>;
|
||||
|
||||
static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel;
|
||||
auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
}};
|
||||
}}
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn1},
|
||||
{F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{{
|
||||
if (a.num_splits <= 16) {{
|
||||
kernel_runner<4>::run(s, a);
|
||||
}} else if (a.num_splits <= 32) {{
|
||||
kernel_runner<5>::run(s, a);
|
||||
}} else if (a.num_splits <= 64) {{
|
||||
kernel_runner<6>::run(s, a);
|
||||
}} else if (a.num_splits <= 128) {{
|
||||
kernel_runner<7>::run(s, a);
|
||||
}}
|
||||
}}
|
||||
|
||||
template<>
|
||||
std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = kernel_runner<6>::fmha_kernel; /// FIXME: choose real kernel type
|
||||
return k_::GetName();
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp"
|
||||
FMHA_FWD_SPLITKV_API="""
|
||||
#include <iostream>
|
||||
|
||||
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
|
||||
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout
|
||||
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
|
||||
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
|
||||
<< std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
|
||||
float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdSplitKVPipeline:
|
||||
tag : str
|
||||
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_skpad == 't' : n += 'sk'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}_v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
return n
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdSplitKVCombinePipeline:
|
||||
tag : str
|
||||
|
||||
F_spad : str # true/false
|
||||
F_dvpad : str #
|
||||
F_lse : str #
|
||||
F_squant : str #
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
return n
|
||||
|
||||
class FmhaFwdSplitKVApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][trait.hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
|
||||
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdSplitKVCombineTileSize:
|
||||
F_bm0 : int # tile size along q seqlen
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn1}" +\
|
||||
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdSplitKVKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdTileSize
|
||||
F_pipeline : FmhaFwdSplitKVPipeline
|
||||
mask_impl : str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
return FMHA_FWD_KERNEL_HEADER + \
|
||||
FMHA_FWD_SPLITKV_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk0blen = self.F_tile.F_bk0blen,
|
||||
F_rm = self.F_tile.F_rm,
|
||||
F_rn = self.F_tile.F_rn,
|
||||
F_rk = self.F_tile.F_rk,
|
||||
F_wm = self.F_tile.F_wm,
|
||||
F_wn = self.F_tile.F_wn,
|
||||
F_wk = self.F_tile.F_wk,
|
||||
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = FMHA_FWD_SPLITKV_PIPELINE_MAP[self.F_pipeline.tag])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_fwd_splitkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0blen=self.F_tile.F_bk0blen,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad)
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdSplitKVCombineKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdSplitKVCombineTileSize
|
||||
F_pipeline : FmhaFwdSplitKVCombinePipeline
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
return FMHA_FWD_KERNEL_HEADER + \
|
||||
FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_mode = MODE_MAP[self.F_mode])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_fwd_splitkv_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0blen=self.F_tile.F_bk0blen,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1),
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1)
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdSplitKVCombineTileSize(64, 32, -1),
|
||||
'64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1),
|
||||
'128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1),
|
||||
'256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1),
|
||||
'128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1),
|
||||
'256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]:
|
||||
Pipeline = FmhaFwdSplitKVPipeline
|
||||
Kernel = FmhaFwdSplitKVKernel
|
||||
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
# splitkv kernel donot support dropout
|
||||
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"]):
|
||||
if hdim == 256:
|
||||
# if True:
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
if receipt == 1:
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
gen = list()
|
||||
api_pool = FmhaFwdSplitKVApiPool(mask_impl)
|
||||
|
||||
for dtype in DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
k = Kernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
if kernel_filter != None:
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if receipt == 2:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaFwdSplitKVCombineKernel]:
|
||||
Pipeline = FmhaFwdSplitKVCombinePipeline
|
||||
Kernel = FmhaFwdSplitKVCombineKernel
|
||||
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for spad, dvpad, lse in itertools.product(["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
pipelines.append(Pipeline('unused', spad, dvpad, lse, squant))
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse kernels
|
||||
pipelines.append(Pipeline('unused', 'f', 'f', 'f', squant))
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
gen = list()
|
||||
|
||||
for dtype in DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
k = Kernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline)
|
||||
if kernel_filter != None:
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
gen.append(k)
|
||||
|
||||
return gen
|
||||
|
||||
def write_single_kernel(kernel: Union[FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) -> None:
|
||||
file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME
|
||||
file_path.write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt)
|
||||
for kernel in kernels:
|
||||
write_single_kernel(kernel, output_dir)
|
||||
api_pool, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_kernel(kernel, output_dir)
|
||||
write_fwd_splitkv_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
_, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n")
|
||||
@@ -118,6 +118,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("drop_seed", "1", "seed for random number generator")
|
||||
.insert("drop_offset", "0", "offset for random number generator")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("num_splits",
|
||||
"1",
|
||||
"# of splits for key/value. 0 to determine actual number by heuristic")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
@@ -159,6 +162,108 @@ auto get_elimit<ck_tile::fp8_t>(std::string init_method)
|
||||
}
|
||||
}
|
||||
|
||||
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
|
||||
{
|
||||
// If we have enough to almost fill the SMs, then just use 1 split
|
||||
if(batch_nhead_mblocks >= 0.8f * num_SMs)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
|
||||
float max_efficiency = 0.f;
|
||||
std::vector<float> efficiency;
|
||||
efficiency.reserve(max_splits);
|
||||
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
|
||||
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
|
||||
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
|
||||
// (i.e. it's 11 splits anyway).
|
||||
// So we check if the number of blocks per split is the same as the previous num_splits.
|
||||
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
|
||||
return num_splits == 1 ||
|
||||
ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
|
||||
};
|
||||
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
|
||||
{
|
||||
if(!is_split_eligible(num_splits))
|
||||
{
|
||||
efficiency.push_back(0.f);
|
||||
}
|
||||
else
|
||||
{
|
||||
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
|
||||
float eff = n_waves / ceil(n_waves);
|
||||
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
|
||||
if(eff > max_efficiency)
|
||||
{
|
||||
max_efficiency = eff;
|
||||
}
|
||||
efficiency.push_back(eff);
|
||||
}
|
||||
}
|
||||
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
|
||||
{
|
||||
if(!is_split_eligible(num_splits))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
|
||||
{
|
||||
// printf("num_splits chosen = %d\n", num_splits);
|
||||
return num_splits;
|
||||
}
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
int override_num_splits_if_necessary(
|
||||
int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
|
||||
{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return num_splits;
|
||||
}
|
||||
|
||||
hipDeviceProp_t props{};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return num_splits;
|
||||
}
|
||||
|
||||
// tile size should match the generate.py
|
||||
const int kM0 = 64;
|
||||
const int kN1 = hdim_v;
|
||||
|
||||
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
|
||||
const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
if(num_splits < 1 && p_drop == 0.0f)
|
||||
{
|
||||
return num_splits_heuristic(
|
||||
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
|
||||
}
|
||||
|
||||
return num_splits;
|
||||
}
|
||||
|
||||
float fmha_fwd_dispatch(fmha_fwd_traits traits,
|
||||
fmha_fwd_args args,
|
||||
const ck_tile::stream_config& config)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(1 < args.num_splits)
|
||||
{
|
||||
return fmha_fwd_splitkv(traits, args, config);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
return fmha_fwd(traits, args, config);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
@@ -266,6 +371,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
seed.reset();
|
||||
}
|
||||
|
||||
int num_splits = arg_parser.get_int("num_splits");
|
||||
|
||||
int stream_warmup = arg_parser.get_int("warmup");
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
@@ -326,6 +433,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
// legalize num_splits according to other options
|
||||
if(num_splits < 1)
|
||||
{
|
||||
num_splits = override_num_splits_if_necessary(
|
||||
batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits);
|
||||
}
|
||||
if(128 < num_splits)
|
||||
{
|
||||
std::cerr << "num_splits greater than 128 is not supported" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto get_lengths = [&](bool permute,
|
||||
ck_tile::index_t b /*batch*/,
|
||||
ck_tile::index_t h /*nhead*/,
|
||||
@@ -375,7 +494,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: std::array<ck_tile::index_t, 2>{batch, nhead})
|
||||
: std::array<ck_tile::index_t, 2>{1, 1});
|
||||
|
||||
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
|
||||
ck_tile::HostTensor<LSEDataType> lse_acc_host(
|
||||
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
ck_tile::HostTensor<OaccDataType> o_acc_host(
|
||||
1 < num_splits
|
||||
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
// self define lse data layout as [batch, nhead, max_seqlen_q]
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
|
||||
@@ -462,6 +589,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
@@ -500,7 +629,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
|
||||
<< ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
|
||||
<< ", mask:" << mask << ", v:" << vlayout << std::flush;
|
||||
<< ", mask:" << mask << ", v:" << vlayout;
|
||||
if(1 < num_splits)
|
||||
{
|
||||
std::cout << ", num_splits:" << num_splits;
|
||||
}
|
||||
std::cout << std::flush;
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
@@ -633,6 +767,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}();
|
||||
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
|
||||
const ck_tile::index_t stride_randval = (max_seqlen_k);
|
||||
const ck_tile::index_t stride_o_acc = hdim_v;
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
@@ -647,6 +782,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = max_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
@@ -655,7 +792,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
// setup split_stride_* arguments (only used in split-kv kernel)
|
||||
const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
|
||||
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
|
||||
|
||||
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
@@ -663,6 +805,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
|
||||
: bias_buf.GetDeviceBuffer(),
|
||||
randval_buf.GetDeviceBuffer(),
|
||||
lse_acc_buf.GetDeviceBuffer(),
|
||||
o_acc_buf.GetDeviceBuffer(),
|
||||
lse_buf.GetDeviceBuffer(),
|
||||
o_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
@@ -676,6 +820,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
hdim_v,
|
||||
nhead,
|
||||
nhead_k,
|
||||
num_splits,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
@@ -685,6 +830,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
|
||||
: stride_bias,
|
||||
stride_randval,
|
||||
stride_o_acc,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
@@ -692,6 +838,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
@@ -699,7 +847,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
batch_stride_bias,
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
batch_stride_o,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
@@ -708,7 +860,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{drop_seed, drop_offset}};
|
||||
}();
|
||||
|
||||
ave_time += fmha_fwd(fmha_traits, fmha_args, stream_config);
|
||||
ave_time += fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);
|
||||
|
||||
if(ave_time < 0)
|
||||
{
|
||||
@@ -997,14 +1149,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
lse_host_result.ForEach(
|
||||
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
|
||||
|
||||
bool lse_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
rtol,
|
||||
atol,
|
||||
/* allow_infinity_ref = */ true);
|
||||
cur_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
rtol,
|
||||
atol,
|
||||
/* allow_infinity_ref = */ true);
|
||||
|
||||
pass &= lse_pass;
|
||||
pass &= cur_pass;
|
||||
if(!cur_pass)
|
||||
{
|
||||
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
|
||||
|
||||
@@ -93,6 +93,8 @@ struct fmha_fwd_args
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
void* rand_val_ptr;
|
||||
void* lse_acc_ptr;
|
||||
void* o_acc_ptr;
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
const void* seqstart_q_ptr;
|
||||
@@ -106,6 +108,7 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
ck_tile::index_t num_splits;
|
||||
float scale_s;
|
||||
float scale_p;
|
||||
float scale_o;
|
||||
@@ -114,6 +117,7 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_o_acc;
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
@@ -121,6 +125,8 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_lse_acc;
|
||||
ck_tile::index_t nhead_stride_o_acc;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
@@ -128,7 +134,11 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t split_stride_lse_acc;
|
||||
ck_tile::index_t split_stride_o_acc;
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
@@ -137,6 +147,50 @@ struct fmha_fwd_args
|
||||
std::tuple<uint64_t, uint64_t> drop_seed_offset;
|
||||
};
|
||||
|
||||
struct fmha_fwd_appendkv_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
void* k_ptr;
|
||||
const void* knew_ptr;
|
||||
void* v_ptr;
|
||||
const void* vnew_ptr;
|
||||
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t seqlen_knew;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
|
||||
const void* rotary_cos_ptr;
|
||||
const void* rotary_sin_ptr;
|
||||
ck_tile::index_t rotary_dim;
|
||||
bool is_rotary_interleaved;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_knew;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_vnew;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_knew;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_vnew;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_knew;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_vnew;
|
||||
};
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
{
|
||||
@@ -234,49 +288,175 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
struct fmha_fwd_appendkv_args
|
||||
template <typename Kernel>
|
||||
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
|
||||
{
|
||||
const void* q_ptr;
|
||||
void* k_ptr;
|
||||
const void* knew_ptr;
|
||||
void* v_ptr;
|
||||
const void* vnew_ptr;
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(Kernel::kIsGroupMode)
|
||||
{
|
||||
return Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.batch,
|
||||
args.max_seqlen_q,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_splits,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o_acc,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.batch_stride_lse_acc,
|
||||
args.batch_stride_o_acc,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.batch,
|
||||
args.max_seqlen_q,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_splits,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o_acc,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse_acc,
|
||||
args.batch_stride_o_acc,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
dim3 grids =
|
||||
Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits);
|
||||
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t seqlen_knew;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
const void* rotary_cos_ptr;
|
||||
const void* rotary_sin_ptr;
|
||||
ck_tile::index_t rotary_dim;
|
||||
bool is_rotary_interleaved;
|
||||
template <typename Kernel>
|
||||
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel argumentszs
|
||||
if constexpr(Kernel::kIsGroupMode)
|
||||
{
|
||||
return Kernel::MakeKargs(args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.batch,
|
||||
args.max_seqlen_q,
|
||||
args.seqstart_q_ptr,
|
||||
args.hdim_v,
|
||||
args.num_splits,
|
||||
args.scale_o,
|
||||
args.stride_o_acc,
|
||||
args.stride_o,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_lse_acc,
|
||||
args.batch_stride_o_acc,
|
||||
args.batch_stride_lse,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return Kernel::MakeKargs(args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.batch,
|
||||
args.max_seqlen_q,
|
||||
args.seqlen_q,
|
||||
args.hdim_v,
|
||||
args.num_splits,
|
||||
args.scale_o,
|
||||
args.stride_o_acc,
|
||||
args.stride_o,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_lse_acc,
|
||||
args.batch_stride_o_acc,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc);
|
||||
}
|
||||
}();
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_knew;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_vnew;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_knew;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_vnew;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_knew;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_vnew;
|
||||
};
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
|
||||
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename Kernel>
|
||||
auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
|
||||
@@ -400,6 +580,40 @@ struct fmha_fwd_traits_
|
||||
template <typename Traits_>
|
||||
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_fwd_splitkv_get_name_();
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN1_,
|
||||
bool kStoreLse_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kPadS_,
|
||||
bool kPadDv_>
|
||||
struct fmha_fwd_splitkv_combine_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_fwd_splitkv_combine_get_name_();
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
@@ -450,6 +664,7 @@ struct fmha_fwd_traits
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
|
||||
float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
struct fmha_fwd_appendkv_traits
|
||||
{
|
||||
|
||||
@@ -11,6 +11,7 @@ from codegen.cmake_config import *
|
||||
from codegen.ops import (
|
||||
fmha_fwd,
|
||||
fmha_fwd_appendkv,
|
||||
fmha_fwd_splitkv,
|
||||
fmha_bwd
|
||||
)
|
||||
|
||||
@@ -22,6 +23,7 @@ class HandlerId(IntEnum):
|
||||
handlers = {
|
||||
'fwd' : (fmha_fwd.list_blobs, fmha_fwd.write_blobs),
|
||||
'fwd_appendkv' : (fmha_fwd_appendkv.list_blobs, fmha_fwd_appendkv.write_blobs),
|
||||
'fwd_splitkv' : (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs),
|
||||
'bwd' : (fmha_bwd.list_blobs, fmha_bwd.write_blobs),
|
||||
}
|
||||
|
||||
@@ -103,4 +105,4 @@ if __name__ == "__main__":
|
||||
if args.list_blobs is not None:
|
||||
list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
|
||||
else:
|
||||
write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
|
||||
write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
|
||||
4
example/ck_tile/02_layernorm2d/CMakeLists.txt
Normal file
4
example/ck_tile/02_layernorm2d/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
add_executable(tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
|
||||
target_compile_options(tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD)
|
||||
22
example/ck_tile/02_layernorm2d/README.md
Normal file
22
example/ck_tile/02_layernorm2d/README.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# Layernorm2D forward
|
||||
|
||||
This folder contains example for Layernorm2D forward using ck_tile tile-programming implementation.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_layernorm2d_fwd -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-m m dimension (default:3328)
|
||||
-n m dimension (default:4096)
|
||||
-e epsilon (default:1e-5)
|
||||
-v cpu validation or not (default:1)
|
||||
-prec precision (default:fp16)
|
||||
```
|
||||
191
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
Normal file
191
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
Normal file
@@ -0,0 +1,191 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
#include <cstring>
|
||||
|
||||
// Host API implementation
|
||||
float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
layernorm2d_fwd_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(t.data_type.compare("fp16") == 0)
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YDataType = ck_tile::half_t;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using BetaDataType = ck_tile::half_t;
|
||||
#ifdef SAVE_MEAN_INV_STD
|
||||
using MeanDataType = ck_tile::half_t;
|
||||
using InvStdDataType = ck_tile::half_t;
|
||||
#else
|
||||
using MeanDataType = ck_tile::null_type;
|
||||
using InvStdDataType = ck_tile::null_type;
|
||||
#endif
|
||||
using ComputeDataType = float;
|
||||
|
||||
using thread_tile = ck_tile::sequence<4, 4>;
|
||||
using warp_tile = ck_tile::sequence<8, 128>;
|
||||
using block_tile = ck_tile::sequence<32, 128>;
|
||||
|
||||
using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;
|
||||
|
||||
using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
MeanDataType,
|
||||
InvStdDataType,
|
||||
Shape>;
|
||||
|
||||
using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(
|
||||
a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a.M);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = Shape::kMWarpPerBlock * Shape::kNWarpPerBlock;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3328", "m dimension")
|
||||
.insert("n", "4096", "m dimension")
|
||||
.insert("e", "1e-5", "epsilon")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YDataType = ck_tile::half_t;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using BetaDataType = ck_tile::half_t;
|
||||
#ifdef SAVE_MEAN_INV_STD
|
||||
using MeanDataType = ck_tile::half_t;
|
||||
using InvStdDataType = ck_tile::half_t;
|
||||
#else
|
||||
using MeanDataType = ck_tile::null_type;
|
||||
using InvStdDataType = ck_tile::null_type;
|
||||
#endif
|
||||
using ComputeDataType = float;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({M, N});
|
||||
ck_tile::HostTensor<GammaDataType> gamma_host({N});
|
||||
ck_tile::HostTensor<BetaDataType> beta_host({N});
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({M, N});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({M, N});
|
||||
|
||||
ck_tile::HostTensor<MeanDataType> mean_host_ref({M});
|
||||
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({M});
|
||||
|
||||
#ifdef SAVE_MEAN_INV_STD
|
||||
ck_tile::HostTensor<MeanDataType> mean_host_dev({M});
|
||||
ck_tile::HostTensor<InvStdDataType> invStd_host_dev({M});
|
||||
#endif
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
|
||||
ck_tile::FillUniformDistribution<GammaDataType>{-5.f, 5.f}(gamma_host);
|
||||
ck_tile::FillUniformDistribution<BetaDataType>{-5.f, 5.f}(beta_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
#ifdef SAVE_MEAN_INV_STD
|
||||
ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes());
|
||||
#endif
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
beta_buf.ToDevice(beta_host.data());
|
||||
|
||||
layernorm2d_fwd_traits traits{data_type};
|
||||
|
||||
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
|
||||
gamma_buf.GetDeviceBuffer(),
|
||||
beta_buf.GetDeviceBuffer(),
|
||||
y_buf.GetDeviceBuffer(),
|
||||
#ifdef SAVE_MEAN_INV_STD
|
||||
mean_buf.GetDeviceBuffer(),
|
||||
invStd_buf.GetDeviceBuffer(),
|
||||
#else
|
||||
nullptr,
|
||||
nullptr,
|
||||
#endif
|
||||
epsilon,
|
||||
M,
|
||||
N};
|
||||
|
||||
float ave_time = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true});
|
||||
|
||||
std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
|
||||
sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << M << ", n:" << N << ", " << ave_time << " ms, " << gb_per_sec << " GB/s"
|
||||
<< std::flush;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
// reference
|
||||
ck_tile::reference_layernorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
MeanDataType,
|
||||
InvStdDataType>(
|
||||
x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
|
||||
|
||||
y_buf.FromDevice(y_host_dev.data());
|
||||
|
||||
pass = ck_tile::check_err(y_host_dev, y_host_ref);
|
||||
|
||||
#ifdef SAVE_MEAN_INV_STD
|
||||
mean_buf.FromDevice(mean_host_dev.data());
|
||||
pass &= ck_tile::check_err(mean_host_dev, mean_host_ref);
|
||||
|
||||
invStd_buf.FromDevice(invStd_host_dev.data());
|
||||
pass &= ck_tile::check_err(invStd_host_dev, invStd_host_ref);
|
||||
#endif
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
|
||||
}
|
||||
|
||||
std::cout << std::endl << std::flush;
|
||||
|
||||
return !pass;
|
||||
}
|
||||
30
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
Normal file
30
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/layernorm2d.hpp"
|
||||
#include <string>
|
||||
|
||||
struct layernorm2d_fwd_traits
|
||||
{
|
||||
std::string data_type;
|
||||
};
|
||||
|
||||
struct layernorm2d_fwd_args
|
||||
{
|
||||
const void* p_x;
|
||||
const void* p_gamma;
|
||||
const void* p_beta;
|
||||
void* p_y;
|
||||
void* p_mean;
|
||||
void* p_invStd;
|
||||
float epsilon;
|
||||
ck_tile::index_t M;
|
||||
ck_tile::index_t N;
|
||||
};
|
||||
|
||||
// host API
|
||||
float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
|
||||
@@ -3,3 +3,4 @@ include_directories(AFTER
|
||||
)
|
||||
|
||||
add_subdirectory(01_fmha)
|
||||
add_subdirectory(02_layernorm2d)
|
||||
|
||||
@@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
|
||||
#define __gfx11__
|
||||
#endif
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#define __gfx12__
|
||||
#endif
|
||||
|
||||
// buffer resource
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
@@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
|
||||
#elif defined(__gfx103__)
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
|
||||
#elif defined(__gfx11__)
|
||||
#elif defined(__gfx11__) || defined(__gfx12__)
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
|
||||
#endif
|
||||
|
||||
@@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
#define CK_USE_AMD_V_FMAC_F32
|
||||
#define CK_USE_AMD_V_DOT2_F32_F16
|
||||
#define CK_USE_AMD_V_DOT4_I32_I8
|
||||
#elif defined(__gfx11__)
|
||||
#elif defined(__gfx11__) || defined(__gfx12__)
|
||||
#define CK_USE_AMD_V_FMAC_F32
|
||||
#define CK_USE_AMD_V_DOT2_F32_F16
|
||||
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
|
||||
@@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
#define CK_USE_AMD_MFMA_GFX940
|
||||
#endif
|
||||
|
||||
// WMMA instruction
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_USE_AMD_WMMA
|
||||
#elif defined(__gfx11__) // for GPU code
|
||||
#define CK_USE_AMD_WMMA
|
||||
#endif
|
||||
|
||||
// buffer load
|
||||
#define CK_USE_AMD_BUFFER_LOAD 1
|
||||
|
||||
|
||||
@@ -84,4 +84,9 @@ inline bool is_gfx11_supported()
|
||||
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
|
||||
}
|
||||
|
||||
inline bool is_gfx12_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -13,6 +13,504 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
#ifdef __gfx12__
|
||||
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatAcc,
|
||||
typename ABlockDesc,
|
||||
typename BBlockDesc,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerWMMA,
|
||||
index_t NPerWMMA,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool AEnableLds = true,
|
||||
bool BEnableLds = true,
|
||||
bool TransposeC = false>
|
||||
/* Option: Read from LDS, big buffer hold all threads required data
|
||||
* Source
|
||||
* A: K0PerBlock x MPerBlock x K1
|
||||
* B: K0PerBlock x NPerBlock x K1
|
||||
* Destination
|
||||
* C, non-transpose
|
||||
* thread level: MRepeat x NRepeat x MAccVgprs
|
||||
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
|
||||
* KPACK == WMMA_K = 16
|
||||
*
|
||||
* Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
|
||||
* Source:
|
||||
* A(if skip LDS): MRepeat x KPack
|
||||
* B(if skip LDS): NRepeat x KPack
|
||||
* Destination
|
||||
* C, non-transpose
|
||||
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
|
||||
*/
|
||||
struct BlockwiseGemmWMMA
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto WmmaK = Number<16>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
// Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
|
||||
static constexpr index_t WaveSize = 32;
|
||||
|
||||
// When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
|
||||
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
|
||||
// permutation
|
||||
static constexpr index_t A_KRow = 2;
|
||||
static constexpr index_t B_KRow = 2;
|
||||
|
||||
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
|
||||
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
|
||||
|
||||
static constexpr auto wmma_gemm =
|
||||
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
FloatAcc,
|
||||
MRepeat * NRepeat,
|
||||
wmma_gemm.GetRegSizePerWmma(),
|
||||
true>
|
||||
c_thread_buf_;
|
||||
|
||||
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
|
||||
|
||||
__device__ static auto GetWaveIdx()
|
||||
{
|
||||
const index_t thread_id = ThisThreadBlock::GetThreadId();
|
||||
|
||||
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
|
||||
}
|
||||
|
||||
// Default, Block buffer in LDS, thread level offset enabled
|
||||
__device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
|
||||
|
||||
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
|
||||
return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(0, 0, 0, 0, 0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
|
||||
|
||||
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
|
||||
return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(0, 0, 0, 0, 0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0>
|
||||
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
|
||||
|
||||
constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
|
||||
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
|
||||
const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
|
||||
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
|
||||
|
||||
return make_tuple(c_thread_m, c_thread_n);
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0>
|
||||
__device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
|
||||
|
||||
return make_tuple(
|
||||
Number<m0>{}, waveId_m, blk_idx[I0], Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
|
||||
}
|
||||
|
||||
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
|
||||
__host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
|
||||
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
|
||||
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
|
||||
{
|
||||
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
|
||||
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
|
||||
|
||||
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
|
||||
NPerBlock % (NPerWMMA * NRepeat) == 0,
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
// transposed WMMA output C' = B' * A'
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
|
||||
{
|
||||
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
|
||||
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
|
||||
|
||||
constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
|
||||
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
|
||||
// |NThreadPerSubGroup |MAccVgprs
|
||||
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
|
||||
}
|
||||
|
||||
// Thread level, register decriptor. Vector-write
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
|
||||
{
|
||||
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
|
||||
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
|
||||
|
||||
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
|
||||
constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
|
||||
return make_naive_tensor_descriptor(
|
||||
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
|
||||
// |NThreadPerSubGroup |MAccVgprs
|
||||
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
|
||||
make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
|
||||
Number<NRepeat>{} * MAccVgprs * AccStride,
|
||||
Number<NRepeat>{} * MAccVgprs * AccStride,
|
||||
MAccVgprs * AccStride,
|
||||
MAccVgprs * AccStride,
|
||||
MAccVgprs * AccStride,
|
||||
AccStride));
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
|
||||
const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
|
||||
transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
|
||||
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
|
||||
|
||||
return wmma_gemm
|
||||
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
|
||||
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
|
||||
}
|
||||
|
||||
// transposed WMMA output C' = B' * A'
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
|
||||
{
|
||||
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
||||
Number<MWaves>{},
|
||||
Number<MPerWMMA>{},
|
||||
Number<NRepeat>{},
|
||||
Number<NWaves>{},
|
||||
Number<NPerWMMA>{}));
|
||||
|
||||
return wmma_gemm
|
||||
.MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
|
||||
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
|
||||
}
|
||||
|
||||
// Provide dimension size
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
|
||||
{
|
||||
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
||||
Number<MWaves>{},
|
||||
Number<MPerWMMA>{},
|
||||
Number<NRepeat>{},
|
||||
Number<NWaves>{},
|
||||
Number<NPerWMMA>{}));
|
||||
|
||||
return wmma_gemm
|
||||
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
|
||||
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
|
||||
}
|
||||
|
||||
// Describe how data allocated in thread copy src buffer
|
||||
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
|
||||
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
|
||||
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
|
||||
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
__device__ void Run(const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
static_assert(KPack % (A_K1 * A_KRow) == 0, "");
|
||||
static_assert(KPack % (B_K1 * B_KRow) == 0, "");
|
||||
|
||||
// basic intrinsic to determine loopover direction
|
||||
if constexpr(MRepeat < NRepeat)
|
||||
{
|
||||
static_for<0, KPerBlock / KPack, 1>{}(
|
||||
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// read A
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read B
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
vector_type<FloatA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<FloatB, KPack / B_KRow> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatA>()(i) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
|
||||
});
|
||||
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
|
||||
b_thread_vec.template AsType<FloatB>()(i) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<FloatA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<FloatB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of
|
||||
// k=0,kpack*1, ..
|
||||
// read B
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
// read A
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
vector_type<FloatA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<FloatB, KPack / B_KRow> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatA>()(i) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / A_K1, m0, 0, 0, 0, i % A_K1))>{}];
|
||||
});
|
||||
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto i) {
|
||||
b_thread_vec.template AsType<FloatB>()(i) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / B_K1, n0, 0, 0, 0, i % B_K1))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<FloatA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<FloatB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
|
||||
make_tuple(Number<A_K1>{},
|
||||
Number<KPack / A_KRow>{},
|
||||
Number<A_K1>{},
|
||||
Number<A_K1>{},
|
||||
Number<A_K1>{},
|
||||
Number<1>{}));
|
||||
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
|
||||
make_tuple(Number<B_K1>{},
|
||||
Number<KPack / B_KRow>{},
|
||||
Number<B_K1>{},
|
||||
Number<B_K1>{},
|
||||
Number<B_K1>{},
|
||||
Number<1>{}));
|
||||
|
||||
// C[M, N, NumRegWMMA]
|
||||
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
|
||||
|
||||
template <bool EnableLds>
|
||||
struct AThreadCopySelector;
|
||||
|
||||
template <>
|
||||
struct AThreadCopySelector<true>
|
||||
{
|
||||
using type =
|
||||
ThreadwiseTensorSliceTransfer_v4<FloatA,
|
||||
FloatA,
|
||||
decltype(a_block_desc_k0_m0_m1_m2_k1),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct AThreadCopySelector<false>
|
||||
{
|
||||
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
|
||||
FloatA,
|
||||
FloatA,
|
||||
decltype(a_block_desc_k0_m0_m1_m2_k1),
|
||||
decltype(a_thread_desc_),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
A_K1,
|
||||
false>;
|
||||
};
|
||||
|
||||
template <bool EnableLds>
|
||||
struct BThreadCopySelector;
|
||||
|
||||
template <>
|
||||
struct BThreadCopySelector<true>
|
||||
{
|
||||
using type =
|
||||
ThreadwiseTensorSliceTransfer_v4<FloatB,
|
||||
FloatB,
|
||||
decltype(b_block_desc_k0_n0_n1_n2_k1),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BThreadCopySelector<false>
|
||||
{
|
||||
using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow<
|
||||
FloatB,
|
||||
FloatB,
|
||||
decltype(b_block_desc_k0_n0_n1_n2_k1),
|
||||
decltype(b_thread_desc_),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
B_K1,
|
||||
false>;
|
||||
};
|
||||
|
||||
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
|
||||
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
|
||||
};
|
||||
#else
|
||||
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
@@ -527,5 +1025,6 @@ struct BlockwiseGemmWMMA
|
||||
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
|
||||
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -487,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
// sync point.
|
||||
if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
asm volatile("\
|
||||
s_barrier_signal -1 \n \
|
||||
s_barrier_wait -1 \
|
||||
" ::);
|
||||
#else
|
||||
asm volatile("s_barrier" ::);
|
||||
#endif
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemm_Streamk_V2 : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t Streamk_sel,
|
||||
ck::index_t Grid_size,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
359
include/ck/tensor_operation/gpu/device/helper.hpp
Normal file
359
include/ck/tensor_operation/gpu/device/helper.hpp
Normal file
@@ -0,0 +1,359 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include <fstream>
|
||||
#include <variant>
|
||||
|
||||
// functions to return the corresponding structs based on generated template parameters
|
||||
|
||||
using layouts = std::variant<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::tensor_layout::convolution::NHWGK,
|
||||
ck::tensor_layout::convolution::GNDHWK,
|
||||
ck::tensor_layout::convolution::NDHWGK>;
|
||||
// return the layout type: currently this is the only type supported in MIOpen
|
||||
auto layout_type(std::string type)
|
||||
{
|
||||
if(type == "ck::tensor_layout::convolution::NHWGK")
|
||||
{
|
||||
return ck::tensor_layout::convolution::NHWGK{};
|
||||
}
|
||||
throw std::runtime_error("Incorrect layout");
|
||||
}
|
||||
// return the right gemm spec based on the generated template parameters
|
||||
ck::tensor_operation::device::GemmSpecialization gemm_type(std::string type)
|
||||
{
|
||||
if(type == "ck::tensor_operation::device::GemmSpecialization::Default")
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
}
|
||||
if(type == "ck::tensor_operation::device::GemmSpecialization::MNKPadding")
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
}
|
||||
throw std::runtime_error("Incorrect gemm spec: " + type);
|
||||
}
|
||||
|
||||
// return the type of convolution
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization conv_type(std::string type)
|
||||
{
|
||||
if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default")
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
}
|
||||
if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0")
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
|
||||
}
|
||||
if(type ==
|
||||
"ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0")
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
|
||||
}
|
||||
if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC")
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
|
||||
}
|
||||
throw std::runtime_error("Incorrect conv spec: " + type);
|
||||
}
|
||||
|
||||
// Function to call on MatrixPadder via a wrapper struct
|
||||
// NOTE: CK only uses MNKPadding for forward convolution
|
||||
template <typename CDesc_MRaw_NRaw>
|
||||
auto pad(ck::index_t mpb,
|
||||
ck::index_t npb,
|
||||
ck::index_t kpb,
|
||||
ck::tensor_operation::device::GemmSpecialization gemm,
|
||||
CDesc_MRaw_NRaw conv)
|
||||
{
|
||||
if(gemm == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
|
||||
{
|
||||
ck::tensor_operation::device::MatrixPadder<
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
ck::index_t,
|
||||
ck::index_t,
|
||||
ck::index_t>
|
||||
a;
|
||||
a.MPerTile_ = mpb;
|
||||
a.NPerTile_ = npb;
|
||||
a.KPerTile_ = kpb;
|
||||
auto tmp = grid_desc(a, conv);
|
||||
return tmp;
|
||||
}
|
||||
throw std::runtime_error("Incorrect template parameters, check gemm spec");
|
||||
}
|
||||
|
||||
// Functions to call on TransformConvFwdToGemm through wrapper: different functions based on num
|
||||
// dims
|
||||
// FIXME: add a way to properly pass in the layout
|
||||
auto transform_conv(ck::index_t num_dim,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
|
||||
ck::Array<ck::index_t, 5> out_lengths,
|
||||
ck::Array<ck::index_t, 5> out_strides)
|
||||
{
|
||||
if(num_dim == 2 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 2 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 2 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
throw std::runtime_error("Incorrect conv spec");
|
||||
}
|
||||
|
||||
auto transform_conv_3d(ck::index_t num_dim,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
|
||||
ck::Array<ck::index_t, 6> out_lengths,
|
||||
ck::Array<ck::index_t, 6> out_strides)
|
||||
{
|
||||
if(num_dim == 3 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 3 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 3 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
throw std::runtime_error("Incorrect conv spec");
|
||||
}
|
||||
|
||||
auto transform_conv_1d(ck::index_t num_dim,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
|
||||
ck::Array<ck::index_t, 4> out_lengths,
|
||||
ck::Array<ck::index_t, 4> out_strides)
|
||||
{
|
||||
if(num_dim == 1 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 1 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 1 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
throw std::runtime_error("Incorrect dims or conv spec");
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
auto block_2_etile(ck::index_t m_per_block, ck::index_t n_per_block, CGridDesc_M_N matrix_padder)
|
||||
{
|
||||
if(m_per_block == 32 && n_per_block == 64)
|
||||
{
|
||||
auto b2e = ck::BlockToCTileMap_M00_N0_M01Adapt<32, 64, CGridDesc_M_N>(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 32 && n_per_block == 128)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<32, 128, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 64 && n_per_block == 32)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<64, 32, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 64 && n_per_block == 64)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<64, 64, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 64 && n_per_block == 128)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<64, 128, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 128 && n_per_block == 32)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 32, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 128 && n_per_block == 64)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 64, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 128 && n_per_block == 128)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 128, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 128 && n_per_block == 256)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 256, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 256 && n_per_block == 128)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<256, 128, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
throw std::runtime_error("Incorrect template parameters");
|
||||
}
|
||||
|
||||
// wrapper functions by dims to get grid size - uses above 3 functions
|
||||
// TODO: eventually remove the 1d/2d versions as CK will only support 3d convolutions
|
||||
auto get_launch_params_1d(ck::host::Solution solution,
|
||||
ck::Array<ck::index_t, 4> out_lengths,
|
||||
ck::Array<ck::index_t, 4> out_strides)
|
||||
{
|
||||
auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
|
||||
auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
|
||||
auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
|
||||
auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
|
||||
auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType);
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType);
|
||||
auto conv_to_gemm_transformer = transform_conv_1d(num_dim, ConvSpec, out_lengths, out_strides);
|
||||
auto matrix_padder =
|
||||
pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
|
||||
auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
|
||||
return b2e;
|
||||
}
|
||||
|
||||
auto get_launch_params(ck::host::Solution solution,
|
||||
ck::Array<ck::index_t, 5> out_lengths,
|
||||
ck::Array<ck::index_t, 5> out_strides)
|
||||
{
|
||||
auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
|
||||
auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
|
||||
auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
|
||||
auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
|
||||
auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType);
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType);
|
||||
auto conv_to_gemm_transformer = transform_conv(num_dim, ConvSpec, out_lengths, out_strides);
|
||||
auto matrix_padder =
|
||||
pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
|
||||
auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
|
||||
return b2e;
|
||||
}
|
||||
|
||||
auto get_launch_params_3d(ck::host::Solution solution,
|
||||
ck::Array<ck::index_t, 6> out_lengths,
|
||||
ck::Array<ck::index_t, 6> out_strides)
|
||||
{
|
||||
auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
|
||||
auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
|
||||
auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
|
||||
auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
|
||||
auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType);
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType);
|
||||
auto conv_to_gemm_transformer = transform_conv_3d(num_dim, ConvSpec, out_lengths, out_strides);
|
||||
auto matrix_padder =
|
||||
pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
|
||||
auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
|
||||
return b2e;
|
||||
}
|
||||
@@ -0,0 +1,781 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
namespace {
|
||||
|
||||
/*
|
||||
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
|
||||
*
|
||||
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
|
||||
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
|
||||
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
|
||||
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
|
||||
* limitations.
|
||||
*
|
||||
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
|
||||
* returns the 2D index of the tile that it computes. \see
|
||||
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
|
||||
*
|
||||
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
|
||||
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
|
||||
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
|
||||
* impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
|
||||
* \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
|
||||
* computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
|
||||
*
|
||||
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
|
||||
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
|
||||
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
|
||||
*
|
||||
*/
|
||||
template <typename GridwiseGemm,
|
||||
typename AsPointer, // tuples if multi AB, pointers if no
|
||||
typename BsPointer,
|
||||
typename DsPointer,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2ETileMap,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
bool HasMainKBlockLoop,
|
||||
bool isMultiA,
|
||||
bool isMultiB>
|
||||
__device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
|
||||
AsPointer p_as_grid,
|
||||
BsPointer p_bs_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const index_t batch_count,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
DsPointer p_ds_grid_grp;
|
||||
|
||||
static constexpr index_t NumDTensor =
|
||||
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
|
||||
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
|
||||
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
AsPointer p_as_grid_grp;
|
||||
BsPointer p_bs_grid_grp;
|
||||
|
||||
const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx);
|
||||
|
||||
static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
|
||||
static_for<0, NumATensor, 1>{}(
|
||||
[&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; });
|
||||
|
||||
const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx);
|
||||
|
||||
static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(
|
||||
p_as_grid_grp,
|
||||
p_bs_grid_grp,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
block_2_ctile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(
|
||||
p_as_grid + a_batch_offset,
|
||||
p_bs_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
block_2_ctile_map);
|
||||
}
|
||||
#else
|
||||
ignore = p_as_grid;
|
||||
ignore = p_bs_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = p_e_grid;
|
||||
ignore = batch_count;
|
||||
ignore = a_grid_desc_k0_m_k1;
|
||||
ignore = b_grid_desc_k0_n_k1;
|
||||
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename AsPointer, // tuples if multi AB, pointers if no
|
||||
typename BsPointer,
|
||||
typename DsPointer,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2ETileMap,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
bool HasMainKBlockLoop,
|
||||
bool isMultiA,
|
||||
bool isMultiB>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
|
||||
AsPointer p_as_grid,
|
||||
BsPointer p_bs_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const index_t batch_count,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
|
||||
device_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
AsPointer, // tuples if multi AB, pointers if no
|
||||
BsPointer,
|
||||
DsPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfBatch,
|
||||
HasMainKBlockLoop,
|
||||
isMultiA,
|
||||
isMultiB>(p_as_grid,
|
||||
p_bs_grid,
|
||||
p_ds_grid,
|
||||
*p_e_grid,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
batch_count,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
block_2_ctile_map,
|
||||
compute_ptr_offset_of_batch);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
//
|
||||
// @brief Device Convolution operation.
|
||||
//
|
||||
// Supports:
|
||||
// @li Forward convolution with up to 3 spatial dimentions
|
||||
// @li Input tensor in GNWC data format
|
||||
// @li Weight tensor in GKXC data format
|
||||
// @li Output tensor in GNWK data format
|
||||
//
|
||||
// 1D:
|
||||
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
|
||||
// 2D:
|
||||
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
|
||||
// 3D:
|
||||
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
|
||||
//
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ConvolutionForwardSpecialization ConvForwardSpecialization,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename ComputeDataType =
|
||||
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
|
||||
Number<0>,
|
||||
ADataType>()), // ComputeType is InputType by default (first
|
||||
// in tuple for MultiAB), unpack if tuple was
|
||||
// passed
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
ComputeDataType>
|
||||
{
|
||||
using DeviceOp = CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
|
||||
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
|
||||
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
|
||||
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto conv_to_gemm_transformer =
|
||||
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
__host__ __device__ static auto
|
||||
MakeAGridDescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const ck::Array<index_t, NDimSpatial>& input_left_pads,
|
||||
const ck::Array<index_t, NDimSpatial>& input_right_pads)
|
||||
{
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
|
||||
template <typename BLay>
|
||||
__host__ __device__ static auto
|
||||
MakeBGridDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
|
||||
{
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides);
|
||||
|
||||
const auto wei_gemmn_gemmk_desc =
|
||||
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
|
||||
|
||||
return wei_gemmn_gemmk_desc;
|
||||
}
|
||||
|
||||
template <typename ELay>
|
||||
__host__ __device__ static auto
|
||||
MakeEGridDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
|
||||
{
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides);
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
|
||||
return out_gemmm_gemmn_desc;
|
||||
}
|
||||
|
||||
// Shape of Ds and E must be aligned. Strides can be different.
|
||||
// Pass e_g_n_k_wos_lengths for logical broadcast.
|
||||
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
|
||||
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(e_g_n_k_wos_lengths,
|
||||
ds_g_n_k_wos_strides[i]);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
|
||||
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
|
||||
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
|
||||
|
||||
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert
|
||||
// it to it
|
||||
using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
|
||||
using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
|
||||
|
||||
#define GridwiseGemmTemplateParameters \
|
||||
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
|
||||
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
|
||||
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
|
||||
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
|
||||
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
|
||||
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm =
|
||||
std::conditional_t<isMultiA || isMultiB,
|
||||
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmTemplateParameters>,
|
||||
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
|
||||
|
||||
// If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers.
|
||||
using APointers =
|
||||
std::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
|
||||
using BPointers =
|
||||
std::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;
|
||||
// Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not
|
||||
// in initializer list what is required for single const pointer).
|
||||
using AGridPointer = remove_cvref_t<
|
||||
decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>;
|
||||
using BGridPointer = remove_cvref_t<
|
||||
decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
// Argument
|
||||
struct Argument
|
||||
{
|
||||
__device__ __host__ Argument(
|
||||
APointers p_as,
|
||||
BPointers p_bs,
|
||||
const ck::Array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
|
||||
const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const ck::Array<index_t, NDimSpatial>& input_left_pads,
|
||||
const ck::Array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
: p_as_grid_{},
|
||||
p_bs_grid_{},
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
num_group_{a_g_n_c_wis_lengths[0]},
|
||||
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)},
|
||||
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides)},
|
||||
ds_grid_desc_m_n_{},
|
||||
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides)},
|
||||
a_grid_desc_ak0_m_ak1_{
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
|
||||
b_grid_desc_bk0_n_bk1_{
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
compute_ptr_offset_of_batch_{},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
|
||||
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
|
||||
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
|
||||
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
|
||||
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
// A/B/E Batch Stride
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
// Init compute_ptr_offset_of_batch_ for multiple AB
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
|
||||
|
||||
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
|
||||
// type is not tuple)
|
||||
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmADataType>>;
|
||||
// It is possible that one of the AB is a pointer and one is a tuple.
|
||||
// Then also use multiAB but we have to cast single pointer instead of tuple of
|
||||
// pointer.
|
||||
if constexpr(isMultiA)
|
||||
{
|
||||
// p_as is tuple
|
||||
p_as_grid_(i) = static_cast<const DataType*>(p_as[i.value]);
|
||||
}
|
||||
else
|
||||
{
|
||||
// if MultiB and not MultiA then p_as is single pointer
|
||||
p_as_grid_(i) = static_cast<const DataType*>(p_as);
|
||||
}
|
||||
});
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
// Init compute_ptr_offset_of_batch_ for multiple AB
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
|
||||
|
||||
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
|
||||
// It is possible that one of the AB is a pointer and one is a tuple.
|
||||
// Then also use multiAB but we have to cast single pointer instead of tuple of
|
||||
// pointer.
|
||||
if constexpr(isMultiB)
|
||||
{
|
||||
// p_bs is tuple
|
||||
p_bs_grid_(i) = static_cast<const DataType*>(p_bs[i.value]);
|
||||
}
|
||||
else
|
||||
{
|
||||
// if MultiA and not MultiB then p_bs is single pointer
|
||||
p_bs_grid_(i) = static_cast<const DataType*>(p_bs);
|
||||
}
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
|
||||
// p_as and p_bs are pointers
|
||||
p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
|
||||
p_bs_grid_(I0) = static_cast<const BDataType*>(p_bs);
|
||||
}
|
||||
|
||||
// populate pointer, batch stride, desc for Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
// D pointer
|
||||
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
|
||||
|
||||
// D batch stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
|
||||
|
||||
// D desc
|
||||
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
|
||||
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]);
|
||||
});
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
|
||||
|
||||
// populate desc for Ds/E
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
const auto as_grid_desc_ak0_m_ak1 =
|
||||
generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number<NumATensor>{});
|
||||
const auto bs_grid_desc_bk0_n_bk1 =
|
||||
generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number<NumBTensor>{});
|
||||
|
||||
if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_))
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n_);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
|
||||
b_grid_desc_n_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_))
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n_);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
// pointers (tuple if multi AB, pointer if no)
|
||||
AGridPointer p_as_grid_;
|
||||
BGridPointer p_bs_grid_;
|
||||
typename GridwiseGemm::DsGridPointer p_ds_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
index_t num_group_;
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_N_K b_grid_desc_n_k_;
|
||||
DsGridDesc_M_N ds_grid_desc_m_n_;
|
||||
EGridDesc_M_N e_grid_desc_m_n_;
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
|
||||
compute_ptr_offset_of_batch_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
|
||||
ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
|
||||
ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
|
||||
ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
|
||||
ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
|
||||
ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
|
||||
ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
|
||||
ck::Array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
ck::Array<index_t, NDimSpatial> conv_filter_dilations_;
|
||||
ck::Array<index_t, NDimSpatial> input_left_pads_;
|
||||
ck::Array<index_t, NDimSpatial> input_right_pads_;
|
||||
};
|
||||
|
||||
static __device__ __host__ auto MakeArgument(
|
||||
APointers p_as,
|
||||
BPointers p_bs,
|
||||
const ck::Array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
|
||||
const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const ck::Array<index_t, NDimSpatial>& input_left_pads,
|
||||
const ck::Array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
return Argument{p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
ds_g_n_k_wos_lengths,
|
||||
ds_g_n_k_wos_strides,
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
|
||||
static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
|
||||
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;
|
||||
static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
|
||||
static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
|
||||
|
||||
static constexpr auto AEnableLds_auto =
|
||||
(NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true;
|
||||
static constexpr auto BEnableLds_auto =
|
||||
(MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? false : true;
|
||||
|
||||
// If true, LDS is used unconditionally
|
||||
static constexpr auto AEnableLds_manu = false;
|
||||
@@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!(arg.a_kz_stride_ == 1 &&
|
||||
arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
|
||||
if(!(arg.a_kz_stride_ == 1))
|
||||
{
|
||||
printf("DeviceOp: Vector Access A-k check failure\n");
|
||||
return false;
|
||||
index_t LastK =
|
||||
AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6);
|
||||
if(LastK % ABlockTransferSrcScalarPerVector == 0)
|
||||
{
|
||||
printf("DeviceOp: Vector Access A-k check failure\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -70,8 +70,9 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
bool pass = true;
|
||||
pass = pass && arg.K_ % K1 == 0;
|
||||
|
||||
@@ -56,7 +56,7 @@ __global__ void
|
||||
bool input_permute,
|
||||
bool output_permute)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
// clang-format off
|
||||
// ***************************************************
|
||||
@@ -159,6 +159,7 @@ __global__ void
|
||||
ignore = O;
|
||||
ignore = G0;
|
||||
ignore = G1;
|
||||
ignore = alpha;
|
||||
ignore = input_permute;
|
||||
ignore = output_permute;
|
||||
#endif // end of if (defined(__gfx11__))
|
||||
@@ -187,7 +188,7 @@ __global__ void
|
||||
index_t head_size,
|
||||
float alpha)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
// clang-format off
|
||||
// ***************************************************
|
||||
@@ -321,7 +322,7 @@ __global__ void
|
||||
index_t head_size,
|
||||
float alpha)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
// clang-format off
|
||||
// ***************************************************
|
||||
@@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const RawArg& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -592,9 +592,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" &&
|
||||
ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" &&
|
||||
std::is_same<ADataType, double>::value)
|
||||
if(!ck::is_lds_direct_load_supported() && std::is_same<ADataType, double>::value)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
|
||||
{
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
|
||||
ck::is_gfx11_supported()))
|
||||
ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
|
||||
is_same_v<AccDataType, int32_t>))
|
||||
|
||||
@@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
|
||||
ck::is_gfx11_supported())
|
||||
ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
|
||||
|
||||
@@ -50,8 +50,9 @@ __global__ void
|
||||
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
|
||||
@@ -552,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
|
||||
|
||||
@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
// K1 = Max Vector Access Pixels
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
|
||||
static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
|
||||
|
||||
static constexpr auto AEnableLds_auto =
|
||||
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
|
||||
static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) &&
|
||||
is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
? false
|
||||
: true;
|
||||
static constexpr auto BEnableLds_auto =
|
||||
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
|
||||
(MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) &&
|
||||
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
? false
|
||||
: true;
|
||||
|
||||
// If true, LDS is used unconditionally
|
||||
static constexpr auto AEnableLds_manu = false;
|
||||
@@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
|
||||
is_same_v<AccDataType, int32_t>))
|
||||
|
||||
@@ -0,0 +1,556 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_streamk_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
dim3 grid_dim;
|
||||
if(arg.Grid_size < 0)
|
||||
{
|
||||
int occupancy, num_cu;
|
||||
hipError_t rtn;
|
||||
rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&occupancy, kernel, BlockSize, 0);
|
||||
hip_check_error(rtn);
|
||||
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
rtn = hipGetDevice(&dev);
|
||||
hip_check_error(rtn);
|
||||
rtn = hipGetDeviceProperties(&dev_prop, dev);
|
||||
hip_check_error(rtn);
|
||||
num_cu = dev_prop.multiProcessorCount;
|
||||
|
||||
arg.Grid_size = num_cu * occupancy;
|
||||
grid_dim = arg.Grid_size;
|
||||
}
|
||||
else
|
||||
grid_dim = arg.Grid_size;
|
||||
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
|
||||
arg_,
|
||||
stream_config.rotating_count,
|
||||
arg_.M * arg_.K * sizeof(ADataType),
|
||||
arg_.K * arg_.N * sizeof(BDataType));
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config, run_flush_cache, kernel, grid_dim, dim3(BlockSize), 0, arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
// Tail number could be One to Seven
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
{
|
||||
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::One>;
|
||||
Run(kernel);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Full)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
Run(kernel);
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Two>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Three)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Three>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Four)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Four>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Five)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Five>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Six>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Seven)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Seven>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Tail number could be Odd or Even
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t streamk_sel,
|
||||
index_t Grid_size,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation)
|
||||
{
|
||||
|
||||
return Argument{
|
||||
p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size}; // HS
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t streamk_sel,
|
||||
index_t Grid_size,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
streamk_sel,
|
||||
Grid_size);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmXdlUniversal"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerXDL<<"x"<<NPerXDL << ", "
|
||||
<< "WaveMap: "
|
||||
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// check device
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -48,8 +48,9 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
@@ -47,12 +47,12 @@ __global__ void
|
||||
#endif
|
||||
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3(
|
||||
typename GridwiseGemm::Argument karg,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const index_t num_k_per_block)
|
||||
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
[[maybe_unused]] const index_t num_k_per_block)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
@@ -103,12 +103,12 @@ __global__ void
|
||||
#endif
|
||||
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds(
|
||||
typename GridwiseGemm::Argument karg,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const index_t num_k_per_block)
|
||||
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
[[maybe_unused]] const index_t num_k_per_block)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
|
||||
@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// check device
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -90,8 +90,9 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -667,7 +668,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported()))
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
|
||||
defined(__gfx11__))
|
||||
defined(__gfx11__) || defined(__gfx12__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -603,7 +603,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
|
||||
ck::is_gfx11_supported()))
|
||||
ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -69,14 +69,15 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffset compute_ptr_offset_of_groups,
|
||||
const ComputePtrOffset compute_ptr_offset_of_n,
|
||||
const index_t groups_count)
|
||||
kernel_grouped_conv_fwd_xdl_cshuffle_v3(
|
||||
typename GridwiseGemm::Argument karg,
|
||||
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups,
|
||||
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n,
|
||||
[[maybe_unused]] const index_t groups_count)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// offset base pointer for each work-group
|
||||
@@ -132,13 +133,13 @@ __global__ void
|
||||
#endif
|
||||
kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds(
|
||||
typename GridwiseGemm::Argument karg,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffset compute_ptr_offset_of_groups,
|
||||
const ComputePtrOffset compute_ptr_offset_of_n,
|
||||
const index_t groups_count)
|
||||
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups,
|
||||
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n,
|
||||
[[maybe_unused]] const index_t groups_count)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// offset base pointer for each work-group
|
||||
|
||||
@@ -582,7 +582,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
// check device
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -39,8 +39,9 @@ __global__ void
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \
|
||||
defined(__gfx12__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
@@ -673,7 +674,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
|
||||
{
|
||||
|
||||
@@ -61,7 +61,7 @@ __global__ void
|
||||
bool input_permute,
|
||||
bool output_permute)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
// clang-format off
|
||||
// ***************************************************
|
||||
@@ -166,6 +166,7 @@ __global__ void
|
||||
ignore = O;
|
||||
ignore = G0;
|
||||
ignore = G1;
|
||||
ignore = alpha;
|
||||
ignore = input_permute;
|
||||
ignore = output_permute;
|
||||
#endif // end of if (defined(__gfx11__))
|
||||
@@ -596,7 +597,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
|
||||
|
||||
static bool IsSupportedArgument(const RawArg& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -60,7 +60,7 @@ __global__ void
|
||||
bool input_permute,
|
||||
bool output_permute)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
// clang-format off
|
||||
// ***************************************************
|
||||
@@ -165,6 +165,7 @@ __global__ void
|
||||
ignore = O;
|
||||
ignore = G0;
|
||||
ignore = G1;
|
||||
ignore = alpha;
|
||||
ignore = input_permute;
|
||||
ignore = output_permute;
|
||||
#endif // end of if (defined(__gfx11__))
|
||||
@@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
|
||||
|
||||
static bool IsSupportedArgument(const RawArg& arg)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -180,6 +180,19 @@ struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KP
|
||||
{
|
||||
};
|
||||
|
||||
// function to take in a struct of type MatrixPadder and call the appropriate function to get
|
||||
// the output descriptor at runtime for codegen
|
||||
template <GemmSpecialization GemmSpec,
|
||||
typename MPerTileType,
|
||||
typename NPerTileType,
|
||||
typename KPerTileType,
|
||||
typename CDesc_MRaw_NRaw>
|
||||
auto grid_desc(MatrixPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType> matrix_padder,
|
||||
CDesc_MRaw_NRaw conv_desc)
|
||||
{
|
||||
auto res = matrix_padder.PadCDescriptor_M_N(conv_desc);
|
||||
return res;
|
||||
}
|
||||
// M/N/KPerTileType could be index_t or Number<>
|
||||
template <bool PadM,
|
||||
bool PadN,
|
||||
|
||||
@@ -1404,4 +1404,326 @@ struct BlockToCTileMap_GemmStreamK
|
||||
}
|
||||
};
|
||||
|
||||
template <uint32_t MPerBlock_,
|
||||
uint32_t NPerBlock_,
|
||||
uint32_t KPerBlock_,
|
||||
StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
|
||||
uint32_t TileSwizzleSubM_ = 8,
|
||||
index_t GroupNum = 8,
|
||||
index_t M01_ = 4>
|
||||
struct BlockToCTileMap_GemmStreamK_v2
|
||||
{
|
||||
static constexpr uint32_t min_k_iters_per_sk_block = 2;
|
||||
static constexpr uint32_t MPerBlock = MPerBlock_;
|
||||
static constexpr uint32_t NPerBlock = NPerBlock_;
|
||||
static constexpr uint32_t KPerBlock = KPerBlock_;
|
||||
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
|
||||
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
|
||||
|
||||
//--------------------------------------
|
||||
// pass to device
|
||||
mutable uint32_t sk_num_blocks;
|
||||
uint32_t sk_num_big_blocks;
|
||||
uint32_t dp_start_block_idx;
|
||||
uint32_t reduction_start_block_idx;
|
||||
uint32_t k_iters_per_big_block;
|
||||
MDiv2 n_tiles;
|
||||
MDiv k_iters_per_tile;
|
||||
MDiv equiv_tiles_big; // for reduction
|
||||
MDiv equiv_tiles_little; // for reduction
|
||||
|
||||
// prefer construct on host
|
||||
__host__ __device__ BlockToCTileMap_GemmStreamK_v2(
|
||||
uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1)
|
||||
{
|
||||
// total output tiles
|
||||
uint32_t num_tiles =
|
||||
math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
|
||||
k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));
|
||||
|
||||
uint32_t dp_tiles, dp_num_blocks, sk_total_iters;
|
||||
|
||||
// default to regular DP GEMM if sk blocks == 0
|
||||
if(streamk_sel == 0)
|
||||
{
|
||||
sk_num_blocks = 0;
|
||||
dp_tiles = num_tiles;
|
||||
sk_num_big_blocks = 0;
|
||||
k_iters_per_big_block = 0;
|
||||
|
||||
dp_num_blocks = num_tiles; // all tile to be dp block
|
||||
dp_start_block_idx = 0;
|
||||
sk_total_iters = 0; // clear this tiles
|
||||
}
|
||||
// 2-tile sk + DP GEMM
|
||||
else
|
||||
{
|
||||
|
||||
// check if there's enough work for DP+ stream-k
|
||||
bool bigEnough = num_tiles > grid_size;
|
||||
// select between stream-k strategies
|
||||
uint32_t sk_tiles = 0;
|
||||
if(streamk_sel == 1) // 1 tile stream-k
|
||||
{
|
||||
sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles;
|
||||
}
|
||||
else if(streamk_sel == 2) // 2-tile stream-k
|
||||
{
|
||||
sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles;
|
||||
}
|
||||
else if(streamk_sel == 3) // 3-tile stream-k
|
||||
{
|
||||
sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size)
|
||||
: num_tiles;
|
||||
}
|
||||
else if(streamk_sel == 4) // 4-tile stream-k
|
||||
{
|
||||
sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size)
|
||||
: num_tiles;
|
||||
}
|
||||
sk_num_blocks = sk_tiles;
|
||||
// remaining tiles are DP tiles
|
||||
dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;
|
||||
|
||||
sk_total_iters = k_iters_per_tile.get() * sk_tiles;
|
||||
|
||||
// k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
|
||||
// we need to decide how many iters for each sk block
|
||||
// let m = k_iters_per_sk_block
|
||||
// some of the sk block (little) will cover m iters, some (big) will cover m+1
|
||||
// we have
|
||||
// 1) l + b = sk_blocks
|
||||
// 2) l * m + b * (m + 1) = sk_total_iters
|
||||
// => (l + b) * m + b = sk_total_iters
|
||||
// => sk_blocks * m + b = sk_total_iters
|
||||
// => b = sk_total_iters - m * sk_blocks
|
||||
// NOTE: big could be zero
|
||||
uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
|
||||
sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
|
||||
k_iters_per_big_block = k_iters_per_sk_block + 1;
|
||||
|
||||
dp_num_blocks = dp_tiles;
|
||||
dp_start_block_idx = sk_num_blocks;
|
||||
}
|
||||
|
||||
n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
|
||||
// using multiple blocks for parallel reduction
|
||||
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
|
||||
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
|
||||
uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
|
||||
equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
|
||||
equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
|
||||
|
||||
return M0 * N0;
|
||||
}
|
||||
__host__ __device__ uint32_t get_sk_total_iters() const
|
||||
{
|
||||
uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
|
||||
(sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
|
||||
return sk_total_iters;
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_sk_tiles() const
|
||||
{
|
||||
// tiles for sk
|
||||
uint32_t sk_total_iters = get_sk_total_iters();
|
||||
return k_iters_per_tile.div(sk_total_iters);
|
||||
}
|
||||
|
||||
__host__ __device__ index_t get_grid_dims() const
|
||||
{
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
|
||||
return reduction_start_block_idx + get_sk_tiles();
|
||||
}
|
||||
else
|
||||
return reduction_start_block_idx;
|
||||
}
|
||||
|
||||
__device__ uint32_t get_block_idx() const
|
||||
{
|
||||
// TODO: swizzle block index for better locality
|
||||
return __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
}
|
||||
|
||||
__device__ void
|
||||
get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
|
||||
{
|
||||
if(block_idx < sk_num_big_blocks)
|
||||
{
|
||||
iter_start = block_idx * k_iters_per_big_block;
|
||||
iter_end = iter_start + k_iters_per_big_block;
|
||||
}
|
||||
else if(block_idx < sk_num_blocks)
|
||||
{
|
||||
iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
|
||||
(block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
|
||||
iter_end = iter_start + (k_iters_per_big_block - 1);
|
||||
}
|
||||
else if(block_idx >= dp_start_block_idx)
|
||||
{
|
||||
uint32_t sk_total_iters = get_sk_total_iters();
|
||||
uint32_t dp_iters_per_block = k_iters_per_tile.get();
|
||||
iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
|
||||
iter_end = iter_start + dp_iters_per_block;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ uint32_t get_current_iter_length(uint32_t iter_start,
|
||||
uint32_t iter_end,
|
||||
uint32_t total_iter_length) const
|
||||
{
|
||||
uint32_t iter_length_mod, iter_length_quo /*unused*/;
|
||||
k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
|
||||
uint32_t current_iter_length = math::min(
|
||||
iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
|
||||
return current_iter_length;
|
||||
}
|
||||
|
||||
__device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }
|
||||
|
||||
__device__ void
|
||||
get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
|
||||
{
|
||||
k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
|
||||
}
|
||||
|
||||
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
|
||||
{
|
||||
uint32_t m_tile_idx, n_tile_idx;
|
||||
uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
|
||||
n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);
|
||||
|
||||
// // swizzle tile
|
||||
uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);
|
||||
|
||||
uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;
|
||||
|
||||
const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
|
||||
? tile_swizzle_sub_m
|
||||
: tile_swizzle_sub_m_rem;
|
||||
|
||||
uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
|
||||
m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
|
||||
m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;
|
||||
|
||||
uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;
|
||||
|
||||
uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
|
||||
|
||||
n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
|
||||
m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
|
||||
return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
|
||||
n_tile_idx_with_adapt);
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
|
||||
{
|
||||
static constexpr uint32_t alignment = 128;
|
||||
uint32_t acc_buffer_bytes =
|
||||
MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
|
||||
return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_workspace_size_for_semaphore() const
|
||||
{
|
||||
return get_sk_tiles() * sizeof(uint32_t);
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
|
||||
{
|
||||
return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
|
||||
const MDiv& equiv_tiles_) const
|
||||
{
|
||||
uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
|
||||
uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1;
|
||||
uint32_t quo_, rem_;
|
||||
equiv_tiles_.divmod(tile_idx_, quo_, rem_);
|
||||
return quo_ * max_equiv_tiles_ + rem_;
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
|
||||
uint32_t iters_per_sk_block_) const
|
||||
{
|
||||
return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
|
||||
1);
|
||||
}
|
||||
|
||||
__host__ __device__ uint32_t get_total_acc_buffers() const
|
||||
{
|
||||
uint32_t tiles_cover_big_blocks =
|
||||
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
|
||||
uint32_t tiles_cover_little_blocks =
|
||||
get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);
|
||||
|
||||
uint32_t total_intersec_big =
|
||||
get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big);
|
||||
uint32_t total_intersec_little =
|
||||
get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little);
|
||||
|
||||
return sk_num_blocks + total_intersec_big + total_intersec_little;
|
||||
}
|
||||
|
||||
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
|
||||
{
|
||||
// TODO: from big to little
|
||||
uint32_t tiles_cover_big_blocks =
|
||||
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
|
||||
if(tile_idx_ < tiles_cover_big_blocks)
|
||||
{
|
||||
uint32_t touched_sk_blocks =
|
||||
(tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
|
||||
k_iters_per_big_block;
|
||||
uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big);
|
||||
return touched_sk_blocks + current_intersec;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
|
||||
uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_;
|
||||
uint32_t touched_sk_blocks =
|
||||
(tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
|
||||
iters_per_little_sk_block;
|
||||
uint32_t current_intersec =
|
||||
get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little);
|
||||
return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
|
||||
{
|
||||
uint32_t iters_per_big_sk_block = k_iters_per_big_block;
|
||||
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
|
||||
if(block_idx_ < sk_num_big_blocks)
|
||||
{
|
||||
uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
|
||||
k_iters_per_tile.get() - 1);
|
||||
uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big);
|
||||
return block_idx_ + current_intersec;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
|
||||
uint32_t touched_tiles = k_iters_per_tile.div(
|
||||
block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
|
||||
uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little);
|
||||
return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user