From cb13839425e0ec4dfff5b8138104ee3e3183d050 Mon Sep 17 00:00:00 2001 From: rocking Date: Mon, 24 Jun 2024 08:45:52 +0800 Subject: [PATCH 01/16] layernorm2d forward (#1339) * Add layernorm2d forward * Refind file path * clang format * Exclude ck_tile op from all * use add_executable instead * refactor layernorm2d_fwd example --------- Co-authored-by: carlushuang --- example/ck_tile/02_layernorm2d/CMakeLists.txt | 4 + example/ck_tile/02_layernorm2d/README.md | 22 ++ .../02_layernorm2d/layernorm2d_fwd.cpp | 191 ++++++++++++ .../02_layernorm2d/layernorm2d_fwd.hpp | 30 ++ example/ck_tile/CMakeLists.txt | 1 + include/ck_tile/core.hpp | 1 + include/ck_tile/core/numeric/null_type.hpp | 13 + include/ck_tile/host.hpp | 1 + include/ck_tile/host/check_err.hpp | 25 +- .../host/reference/reference_layernorm2d.hpp | 69 +++++ include/ck_tile/ops/layernorm2d.hpp | 9 + .../kernel/layernorm2d_fwd_kernel.hpp | 291 ++++++++++++++++++ .../block_layernorm2d_fwd_problem.hpp | 30 ++ .../pipeline/tile_layernorm2d_fwd_shape.hpp | 35 +++ include/ck_tile/ops/welford.hpp | 8 + .../ops/welford/thread/thread_welford.hpp | 101 ++++++ .../ck_tile/ops/welford/warp/warp_welford.hpp | 154 +++++++++ 17 files changed, 975 insertions(+), 10 deletions(-) create mode 100644 example/ck_tile/02_layernorm2d/CMakeLists.txt create mode 100644 example/ck_tile/02_layernorm2d/README.md create mode 100644 example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp create mode 100644 example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp create mode 100644 include/ck_tile/core/numeric/null_type.hpp create mode 100644 include/ck_tile/host/reference/reference_layernorm2d.hpp create mode 100644 include/ck_tile/ops/layernorm2d.hpp create mode 100644 include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp create mode 100644 include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp create mode 100644 include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp create mode 100644 include/ck_tile/ops/welford.hpp create mode 100644 include/ck_tile/ops/welford/thread/thread_welford.hpp create mode 100644 include/ck_tile/ops/welford/warp/warp_welford.hpp diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt new file mode 100644 index 0000000000..bac5f45cd3 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt @@ -0,0 +1,4 @@ +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +add_executable(tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp) +target_compile_options(tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD) \ No newline at end of file diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md new file mode 100644 index 0000000000..433dad04e6 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/README.md @@ -0,0 +1,22 @@ +# Layernorm2D forward + +This folder contains example for Layernorm2D forward using ck_tile tile-programming implementation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_example_layernorm2d_fwd -j +``` +This will result in an executable `build/bin/tile_example_layernorm2d_fwd` + +## example +``` +args: + -m m dimension (default:3328) + -n m dimension (default:4096) + -e epsilon (default:1e-5) + -v cpu validation or not (default:1) + -prec precision (default:fp16) +``` \ No newline at end of file diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp new file mode 100644 index 0000000000..9cbd286104 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -0,0 +1,191 @@ +#include "ck_tile/host.hpp" +#include "layernorm2d_fwd.hpp" +#include + +// Host API implementation +float layernorm2d_fwd(layernorm2d_fwd_traits t, + layernorm2d_fwd_args a, + const ck_tile::stream_config& s) +{ + if(t.data_type.compare("fp16") == 0) + { + using XDataType = ck_tile::half_t; + using YDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using BetaDataType = ck_tile::half_t; +#ifdef SAVE_MEAN_INV_STD + using MeanDataType = ck_tile::half_t; + using InvStdDataType = ck_tile::half_t; +#else + using MeanDataType = ck_tile::null_type; + using InvStdDataType = ck_tile::null_type; +#endif + using ComputeDataType = float; + + using thread_tile = ck_tile::sequence<4, 4>; + using warp_tile = ck_tile::sequence<8, 128>; + using block_tile = ck_tile::sequence<32, 128>; + + using Shape = ck_tile::TileLayernorm2dShape; + + using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem; + + using Kernel = ck_tile::Layernorm2dFwd; + + auto kargs = Kernel::MakeKargs( + a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N); + + const dim3 grids = Kernel::GridSize(a.M); + constexpr dim3 blocks = Kernel::BlockSize(); + + constexpr ck_tile::index_t kBlockPerCu = Shape::kMWarpPerBlock * Shape::kNWarpPerBlock; + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + } + + return 0; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "m dimension") + .insert("e", "1e-5", "epsilon") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +int main(int argc, char* argv[]) +{ + + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + float epsilon = arg_parser.get_float("e"); + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + + using XDataType = ck_tile::half_t; + using YDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using BetaDataType = ck_tile::half_t; +#ifdef SAVE_MEAN_INV_STD + using MeanDataType = ck_tile::half_t; + using InvStdDataType = ck_tile::half_t; +#else + using MeanDataType = ck_tile::null_type; + using InvStdDataType = ck_tile::null_type; +#endif + using ComputeDataType = float; + + // host verify + ck_tile::HostTensor x_host({M, N}); + ck_tile::HostTensor gamma_host({N}); + ck_tile::HostTensor beta_host({N}); + + ck_tile::HostTensor y_host_ref({M, N}); + ck_tile::HostTensor y_host_dev({M, N}); + + ck_tile::HostTensor mean_host_ref({M}); + ck_tile::HostTensor invStd_host_ref({M}); + +#ifdef SAVE_MEAN_INV_STD + ck_tile::HostTensor mean_host_dev({M}); + ck_tile::HostTensor invStd_host_dev({M}); +#endif + + ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); + ck_tile::FillUniformDistribution{-5.f, 5.f}(gamma_host); + ck_tile::FillUniformDistribution{-5.f, 5.f}(beta_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + +#ifdef SAVE_MEAN_INV_STD + ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes()); +#endif + + x_buf.ToDevice(x_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + beta_buf.ToDevice(beta_host.data()); + + layernorm2d_fwd_traits traits{data_type}; + + layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + beta_buf.GetDeviceBuffer(), + y_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + mean_buf.GetDeviceBuffer(), + invStd_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + epsilon, + M, + N}; + + float ave_time = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true}); + + std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + + sizeof(BetaDataType) * N + sizeof(YDataType) * M * N; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << "[" << data_type << "]" + << " m:" << M << ", n:" << N << ", " << ave_time << " ms, " << gb_per_sec << " GB/s" + << std::flush; + + bool pass = true; + + if(do_validation) + { + // reference + ck_tile::reference_layernorm2d_fwd( + x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon); + + y_buf.FromDevice(y_host_dev.data()); + + pass = ck_tile::check_err(y_host_dev, y_host_ref); + +#ifdef SAVE_MEAN_INV_STD + mean_buf.FromDevice(mean_host_dev.data()); + pass &= ck_tile::check_err(mean_host_dev, mean_host_ref); + + invStd_buf.FromDevice(invStd_host_dev.data()); + pass &= ck_tile::check_err(invStd_host_dev, invStd_host_ref); +#endif + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; + } + + std::cout << std::endl << std::flush; + + return !pass; +} diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp new file mode 100644 index 0000000000..4d1aac0994 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/layernorm2d.hpp" +#include + +struct layernorm2d_fwd_traits +{ + std::string data_type; +}; + +struct layernorm2d_fwd_args +{ + const void* p_x; + const void* p_gamma; + const void* p_beta; + void* p_y; + void* p_mean; + void* p_invStd; + float epsilon; + ck_tile::index_t M; + ck_tile::index_t N; +}; + +// host API +float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index d2b086e043..995d193f10 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -3,3 +3,4 @@ include_directories(AFTER ) add_subdirectory(01_fmha) +add_subdirectory(02_layernorm2d) diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index bb490cce4a..4cddf6faa9 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -27,6 +27,7 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/numeric/null_type.hpp" #include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/vector_type.hpp" diff --git a/include/ck_tile/core/numeric/null_type.hpp b/include/ck_tile/core/numeric/null_type.hpp new file mode 100644 index 0000000000..8799c0560e --- /dev/null +++ b/include/ck_tile/core/numeric/null_type.hpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include + +namespace ck_tile { + +struct null_type +{ +}; + +} // namespace ck_tile diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 09030fa6df..0e69a925d5 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -18,6 +18,7 @@ #include "ck_tile/host/reference/reference_batched_softmax.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" +#include "ck_tile/host/reference/reference_layernorm2d.hpp" #include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/stream_config.hpp" diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 1ef9b24138..529bfdff25 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -56,8 +56,9 @@ check_err(const Range& out, } const auto is_infinity_error = [=](auto o, auto r) { - const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); - const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = + std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; @@ -114,8 +115,9 @@ check_err(const Range& out, } const auto is_infinity_error = [=](auto o, auto r) { - const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); - const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = + std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; @@ -173,8 +175,9 @@ check_err(const Range& out, } const auto is_infinity_error = [=](auto o, auto r) { - const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); - const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = + std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; @@ -285,8 +288,9 @@ std::enable_if_t<(std::is_same_v, ranges::range_val } const auto is_infinity_error = [=](auto o, auto r) { - const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); - const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = + std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; @@ -357,8 +361,9 @@ std::enable_if_t<(std::is_same_v, ranges::range_val } const auto is_infinity_error = [=](auto o, auto r) { - const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); - const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = + std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; diff --git a/include/ck_tile/host/reference/reference_layernorm2d.hpp b/include/ck_tile/host/reference/reference_layernorm2d.hpp new file mode 100644 index 0000000000..837f52c399 --- /dev/null +++ b/include/ck_tile/host/reference/reference_layernorm2d.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template +void reference_layernorm2d_fwd(const HostTensor& x_m_n, + const HostTensor& gamma_n, + const HostTensor& beta_n, + HostTensor& y_m_n, + HostTensor& mean_m, + HostTensor& invStd_m, + ComputeDataType epsilon) +{ + auto layernorm2d_fwd_func = [&](auto m) { + const int N = x_m_n.mDesc.get_lengths()[1]; + + int count = 0; + ComputeDataType mean = 0; + ComputeDataType variance = 0; + ComputeDataType divisor = 0; + + for(int n = 0; n < N; ++n) + { + ++count; + ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + ComputeDataType delta = x - mean; + mean += delta / count; + ComputeDataType delta2 = x - mean; + variance += delta * delta2; + } + + // actual variance + variance = variance / count; + divisor = ck_tile::type_convert(1) / ck_tile::sqrt(variance + epsilon); + + if constexpr(!std::is_same_v) + mean_m(m) = ck_tile::type_convert(mean); + + if constexpr(!std::is_same_v) + invStd_m(m) = ck_tile::type_convert(divisor); + + for(int n = 0; n < N; ++n) + { + ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + ComputeDataType gamma = ck_tile::type_convert(gamma_n(n)); + ComputeDataType beta = ck_tile::type_convert(beta_n(n)); + auto y = (x - mean) * divisor; + y = y * gamma + beta; + + y_m_n(m, n) = ck_tile::type_convert(y); + } + }; + + make_ParallelTensorFunctor(layernorm2d_fwd_func, + mean_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp new file mode 100644 index 0000000000..3b66645ed4 --- /dev/null +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp new file mode 100644 index 0000000000..4be3e56874 --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -0,0 +1,291 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/welford/thread/thread_welford.hpp" +#include "ck_tile/ops/welford/warp/warp_welford.hpp" + +namespace ck_tile { + +// TODO: Extract some type to wrapper class +template +struct Layernorm2dFwd +{ + using Problem = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using GammaDataType = ck_tile::remove_cvref_t; + using BetaDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using MeanDataType = ck_tile::remove_cvref_t; + using InvStdDataType = ck_tile::remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kHasBeta = !std::is_same_v; + static constexpr bool kSaveMean = !std::is_same_v; + static constexpr bool kSaveInvStd = !std::is_same_v; + + static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock; + static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock; + + static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp; + + struct Kargs + { + const void* p_x; + const void* p_gamma; + const void* p_beta; + + void* p_y; + void* p_mean; + void* p_invStd; + + float epsilon; + + ck_tile::index_t M; + ck_tile::index_t N; + }; + + CK_TILE_HOST static constexpr Kargs MakeKargs(const void* p_x, + const void* p_gamma, + const void* p_beta, + void* p_y, + void* p_mean, + void* p_invStd, + float epsilon, + ck_tile::index_t M, + ck_tile::index_t N) + { + return Kargs{p_x, p_gamma, p_beta, p_y, p_mean, p_invStd, epsilon, M, N}; + } + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M) { return M / kMPerBlock; } + + CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; } + + CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 1>>, + sequence<1, 2>, + sequence<2, 2>>{}); + } + + CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 1>>, + sequence<1>, + sequence<2>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr) + { + constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>(); + + using Lengths = decltype(nDstrSpan.impl_); + + ck_tile::index_t ret = 1; + + ck_tile::static_for<0, Lengths::size(), 1>{}( + [&](auto idx) { ret *= Lengths::template at(idx); }); + + return ret; + } + + template + CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor, + const ComputeDataType epsilon) + { + // TODO: Investigate fast inverse square root algorithm with epsilon + constexpr auto spans = DistributedTensor::get_distributed_spans(); + + DistributedTensor out_dstr_tensor; + + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + out_dstr_tensor(i_idx) = type_convert(1.0f) / + ck_tile::sqrt(in_dstr_tensor[i_idx] + epsilon); + }); + + return out_dstr_tensor; + } + + template + CK_TILE_DEVICE std::enable_if_t TwoPassLayernorm2dFwd(const XDataType* p_x, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + YDataType* p_y, + MeanDataType* p_mean, + InvStdDataType* p_invStd, + const ComputeDataType epsilon, + ck_tile::index_t M, + ck_tile::index_t N) const + { + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + const auto x_m_n = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{}); + + const auto gamma_n = make_naive_tensor_view( + p_gamma, make_tuple(N), make_tuple(1), number<32>{}, number<1>{}); + + const auto beta_n = make_naive_tensor_view( + p_beta, make_tuple(N), make_tuple(1), number<32>{}, number<1>{}); + + const auto iM = get_block_id() * kMPerBlock; + + constexpr auto xDstr = MakeXBlockTileDistribution(); + + auto x_block_window = make_tile_window( + x_m_n, make_tuple(number{}, number{}), {iM, 0}, xDstr); + + index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(N / kNPerBlock); + + // TODO: padding - handle max_count if N % kNPerBlock != 0 + constexpr auto NPerThread = GetNPerThread(xDstr); + ThreadWelford thread_welford{ + type_convert(NPerThread * N / kNPerBlock)}; + + using XTensorType = decltype(load_tile(x_block_window)); + auto mean_compute_block_tensor = + thread_welford.template MakeInitialMeanVarDistributedTensor(); + auto var_compute_block_tensor = + thread_welford.template MakeInitialMeanVarDistributedTensor(); + + clear_tile(mean_compute_block_tensor); + clear_tile(var_compute_block_tensor); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x_block_tensor = load_tile(x_block_window); + + thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor); + move_tile_window(x_block_window, {0, kNPerBlock}); + } + + // TODO: support cross warp Welford + WarpMergeWelford{}( + mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_); + + auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon); + + if constexpr(kSaveMean) + { + const auto mean_m = make_naive_tensor_view_packed( + p_mean, make_tuple(M), number<32>{}); + + auto mean_block_window = + make_tile_window(mean_m, make_tuple(number{}), {iM}); + + store_tile(mean_block_window, cast_tile(mean_compute_block_tensor)); + } + if constexpr(kSaveInvStd) + { + const auto inv_std_m = make_naive_tensor_view_packed( + p_invStd, make_tuple(M), number<32>{}); + + auto inv_std_block_window = + make_tile_window(inv_std_m, make_tuple(number{}), {iM}); + + store_tile(inv_std_block_window, cast_tile(inv_std_compute_block_tensor)); + } + + // TODO: Extract normalize pipeline + const auto y_m_n = make_naive_tensor_view( + p_y, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{}); + + auto y_block_window = make_tile_window( + y_m_n, make_tuple(number{}, number{}), {iM, 0}); + + constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution(); + constexpr auto betaDstr = gammaDstr; + + auto gamma_block_window = + make_tile_window(gamma_n, make_tuple(number{}), {0}, gammaDstr); + + auto beta_block_window = make_tile_window( + beta_n, make_tuple(number{}, number{}), {0}, betaDstr); + + // reverse read x to reuse cache + ck_tile::index_t stride_to_right_most_window = N - kNPerBlock; + + move_tile_window(x_block_window, {0, -kNPerBlock}); + move_tile_window(gamma_block_window, {stride_to_right_most_window}); + move_tile_window(beta_block_window, {stride_to_right_most_window}); + move_tile_window(y_block_window, {0, stride_to_right_most_window}); + + // Normalization + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x_block_tensor = load_tile(x_block_window); + const auto gamma_block_tensor = load_tile(gamma_block_window); + const auto beta_block_tensor = load_tile(beta_block_window); + + constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans(); + + auto y_block_tensor = + make_static_distributed_tensor(x_block_tensor.get_tile_distribution()); + + sweep_tile_span(x_spans[I1], [&](auto idx1) { + constexpr auto j_idx = make_tuple(idx1); + const auto gamma = type_convert(gamma_block_tensor[j_idx]); + const auto beta = type_convert(beta_block_tensor[j_idx]); + + sweep_tile_span(x_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + const auto mean = mean_compute_block_tensor[i_idx]; + const auto inv_std = inv_std_compute_block_tensor[i_idx]; + + const auto x = type_convert(x_block_tensor[i_j_idx]); + auto y = (x - mean) * inv_std * gamma + beta; + + y_block_tensor(i_j_idx) = type_convert(y); + }); + }); + + store_tile(y_block_window, y_block_tensor); + + move_tile_window(x_block_window, {0, -kNPerBlock}); + move_tile_window(gamma_block_window, {-kNPerBlock}); + move_tile_window(beta_block_window, {-kNPerBlock}); + move_tile_window(y_block_window, {0, -kNPerBlock}); + } + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + TwoPassLayernorm2dFwd(static_cast(kargs.p_x), + static_cast(kargs.p_gamma), + static_cast(kargs.p_beta), + static_cast(kargs.p_y), + static_cast(kargs.p_mean), + static_cast(kargs.p_invStd), + static_cast(kargs.epsilon), + kargs.M, + kargs.N); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp b/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp new file mode 100644 index 0000000000..5206d36d7d --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct BlockLayernorm2dFwdProblem +{ + using XDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using BetaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using MeanDataType = remove_cvref_t; + using InvStdDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp b/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp new file mode 100644 index 0000000000..1ff541d844 --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +template // Sequence<... +struct TileLayernorm2dShape +{ + static constexpr index_t kMPerThread = ThreadTile::at(number<0>{}); + static constexpr index_t kNPerThread = ThreadTile::at(number<1>{}); + + static constexpr index_t kMPerWarp = WarpTile::at(number<0>{}); + static constexpr index_t kNPerWarp = WarpTile::at(number<1>{}); + + static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread; + static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread; + + static constexpr index_t kMPerBlock = BlockTile::at(number<0>{}); + static constexpr index_t kNPerBlock = BlockTile::at(number<1>{}); + + static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp; + static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp; + + // TODO - kNNumWarps can only be 1 if we don't support cross warp welford + static_assert(kNWarpPerBlock == 1); + + static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kNWarpPerBlock; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/welford.hpp b/include/ck_tile/ops/welford.hpp new file mode 100644 index 0000000000..dffaad7501 --- /dev/null +++ b/include/ck_tile/ops/welford.hpp @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/welford/thread/thread_welford.hpp" +#include "ck_tile/ops/welford/warp/warp_welford.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/welford/thread/thread_welford.hpp b/include/ck_tile/ops/welford/thread/thread_welford.hpp new file mode 100644 index 0000000000..2ca9a23657 --- /dev/null +++ b/include/ck_tile/ops/welford/thread/thread_welford.hpp @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct ThreadWelford +{ + using XDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + + template + CK_TILE_DEVICE void Update(T& mean, T& var, T x) + { + if(ck_tile::isnan(x)) + { + mean = x; + var = x; + } + else + { + T delta = x - mean; + mean += delta / cur_count_; + T delta2 = x - mean; + var += delta * delta2; + } + } + + // [CAUSION] - max_count_ is to deal with the padding problem + // max_count_ is depend on caller, eg: naive and splitN welford will have different + // calculation of max_count_ + CK_TILE_DEVICE constexpr ThreadWelford(int max_count) : cur_count_(0), max_count_(max_count) {} + + template + CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor, + MeanDistributedTensor_& mean_tensor, + VarDistributedTensor_& var_tensor) + { + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + constexpr auto spans = XDistributedTensor_::get_distributed_spans(); + + sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) { + if(cur_count_ < max_count_) + { + ++cur_count_; + + sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) { + constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1); + constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0); + + auto x = ck_tile::type_convert(x_tensor[in_dstr_idx]); + + Update(mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x); + }); + } + }); + } + + template + CK_TILE_DEVICE static auto MakeInitialMeanVarDistributedTensor() + { + static_assert(std::is_same_v, "wrong!"); + + constexpr auto reduce_dims = sequence<1>{}; + + constexpr auto dstr = + make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding( + XDistributedTensor_::get_tile_distribution() + .get_static_tile_distribution_encoding(), + reduce_dims)); + + auto tensor = make_static_distributed_tensor(dstr); + clear_tile(tensor); + + return tensor; + } + + template + CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor) + { + auto mean_tensor = MakeInitialMeanVarDistributedTensor(); + auto var_tensor = MakeInitialMeanVarDistributedTensor(); + + (*this)(x_tensor, mean_tensor, var_tensor); + + return ck_tile::make_tuple(mean_tensor, var_tensor); + } + + int cur_count_; + int max_count_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/welford/warp/warp_welford.hpp b/include/ck_tile/ops/welford/warp/warp_welford.hpp new file mode 100644 index 0000000000..687b61f430 --- /dev/null +++ b/include/ck_tile/ops/welford/warp/warp_welford.hpp @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct WarpMergeWelford +{ + using ComputeDataType = remove_cvref_t; + + template + CK_TILE_DEVICE static void + Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b) + { + int count = count_a + count_b; + T count_ = type_convert(count); + T count_a_ = type_convert(count_a); + T count_b_ = type_convert(count_b); + T count_b_over_count = count == 0 ? type_convert(0) : count_b_ / count_; + + T delta = mean_b - mean_a; + mean_a += delta * count_b_over_count; + var_a += var_b + delta * delta * count_a_ * count_b_over_count; + count_a = count; + } + + template + CK_TILE_DEVICE void + operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count) + { + using Dstr = typename MeanDistributedTensor_::StaticTileDistribution; + using DstrEncode = typename Dstr::DstrEncode; + using DstrEncodeDetail = typename DstrEncode::detail; + + static_assert(std::is_same_v, + "wrong!"); + + constexpr index_t NDimP = Dstr::get_num_of_dimension_p(); + constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); + + constexpr index_t idim_p_lane = NDimP - 1; + + const auto ps_idx = make_array(get_warp_id(), get_lane_id()); + const auto rs_idx = + mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx); + + constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size(); + static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()); + + const int original_count = count; + + // loop over thread data + static_for<0, thread_buf_size, 1>{}([&](auto i) { + auto v_local_mean = mean_tensor.get_thread_buffer()[i]; + auto v_local_var = var_tensor.get_thread_buffer()[i]; + auto v_local_count = original_count; + + // cross-lane reduce for replication + // only reduce on R dimension correspond to lane + // (lane id maps to this R dimension) + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r]; + + static_assert(is_power_of_two_integer(r_length), + "wrong! only support power of 2 reduction"); + + constexpr index_t nstage = integer_log2_floor(r_length); + + // reduction sweep forward + static_for<0, nstage, 1>{}([&](auto istage) { + constexpr index_t lid_delta = + lid_over_rid_derivative * (1 << (nstage - istage - 1)); + + // pull data from remote lane + const auto v_remote_mean = warp_shuffle_down(v_local_mean, lid_delta); + const auto v_remote_var = warp_shuffle_down(v_local_var, lid_delta); + const auto v_remote_count = warp_shuffle_down(v_local_count, lid_delta); + + // welford merge + Merge(v_local_mean, + v_local_var, + v_local_count, + v_remote_mean, + v_remote_var, + v_remote_count); + }); + } + }); + + // cross-lane broadcast for replication + // only broadcast on R dimension correspond to lane + // (lane id maps to this R dimension) + if constexpr(BroadcastLane) + { + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + const index_t r_id = rs_idx[idim_r]; + + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r]; + + static_assert(is_power_of_two_integer(r_length), + "wrong! only support power of 2 reduction"); + + constexpr index_t nstage = integer_log2_floor(r_length); + + // broadcast sweep backward + static_for<0, nstage, 1>{}([&](auto istage) { + // do I hold reduced data? + const bool do_i_hold_reduced_data = r_id < (1 << istage); + + constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage); + + // pull data from remote lane + const auto v_remote_mean = warp_shuffle_up(v_local_mean, lid_delta); + const auto v_remote_var = warp_shuffle_up(v_local_var, lid_delta); + const auto v_remote_count = warp_shuffle_up(v_local_count, lid_delta); + + // decide whether to update local data with remote data + v_local_mean = do_i_hold_reduced_data ? v_local_mean : v_remote_mean; + v_local_var = do_i_hold_reduced_data ? v_local_var : v_remote_var; + v_local_count = do_i_hold_reduced_data ? v_local_count : v_remote_count; + }); + } + }); + } + + mean_tensor.get_thread_buffer()(i) = v_local_mean; + + if constexpr(GetActualVariance) + var_tensor.get_thread_buffer()(i) = v_local_var / v_local_count; + else + var_tensor.get_thread_buffer()(i) = v_local_var; + + count = v_local_count; + }); + } +}; + +} // namespace ck_tile From 3e9711f0cb1c7ffd3826a93dfa6dd65e98715636 Mon Sep 17 00:00:00 2001 From: arai713 <67439843+arai713@users.noreply.github.com> Date: Tue, 25 Jun 2024 14:37:35 -0700 Subject: [PATCH 02/16] CK Instance Gen (#1145) * Format * Format * Format * Remove const * Use the right template * Format * Format * add row/col instances * Add missing file * fixed * fixing block to etile error * Format * Updates * Format * fixed rrr layout * generating a sample JSON file: currently contains includes, prologue/epilogue and instances * version where the json is passed into the instances to generate a key * updated run function to just launch kernel * updated run function: only contains kernel object, json file is updated but still needs to be cleaned up, added front-end API to parse JSON into character buffer * adding in testing files * cleaned up comments, still need to work on including header files * removed unneeded files * removed/commented out JSON implementation * added fusion(prologue/epilogue) into instance generation * working on instance selection * added instance selection, need to fix instance validation * removed block2etile map validity check for testing purposes * test running: failing due to incorrect files/input * all grid descs/ptrs completed, but device file not found * Update test and embed modules * Restore older version * added convolution operation, written test, debugging generated code for compilation * attempting to include CK in host directory: _Float16 error * CK header file issues * slight fix * don't crash when hip can't report total memory * dump generated code to a file * changing sizes * creating tensor descriptors using CK methods: set up grid desc manually, also trying to set up an argument pointer - this needs to be fixed * some fixes to call the device code * separating test files for conv and gemm * completed arg ptr, now have linking errors * clang format fix * resolved linker issues in conv test * remove dependency on libutility from ck * resolved num dim error * properly passing arg ptr, errors with passing typenames: redefinition/redeclaration * undo the commenting of device function * hand created kernel code to find rtc issues * dump the full src to file * resolved redeclaration errors, cleaned up errors for Amber's kernel code * debugging purposes: redeclaration error * config files * resolved errors for NumTensor and redeclaration, formatted version.h * resolved most errors in manually added kernel and my own. error with calling kernel object: overloaded function type * WIP: close to getting kernel compiled * WIP: fixing rtc errors * fixed sequence errors, formatting, still one error with run fcn * yay: kernel compiles and runs * updated templated/generated version to run and compile * minor fixes * working generated example, resolved memory access error due to padding * adding in reference kernel, validation failing against reference * debugging: printing kernel argsz * reduced error in results * debugged reference kernel and output errors, added to generated version, currently debugging prologue function issues * working validation (using reference convolution) with prologue function for both hard-coded and generated version * WIP: create an alt version that creates Argument on the device * wip: added new duplicate files, fixed fusion templating errors from working example, setting up kernel arguments * wip: making necessary methods device code * added grid descs, working on grid pointers, errors with stl numerics * wip: updating kernel args - issue, replacing some std functions * replaced std::accumulate call with temp hardcoded version * wip: args causing memory issue * Construct Argument object inside the kernel and use it to call convolution device function. Code runs and verification passes * adding object file dump * temporary hardcoding of grid size, can remove device op inst + arg ptr * minor fix for grid size * added modified example where arg ptr is created on the device for generated version as well * removed device op instance and arg ptr from modified examples * moving device op file for testing purposes and to properly build CK * commenting out print-outs * adjust compiler args to produce a valid ELF file * temporary removal of validation * reverting compiler args back for working example * retrieve necessary arguments from generated template parameters in correct format * calculating grid size on host-side, still need to clean up process, pass parameters to host functions properly * scaled up factory functions/wrapper structs to implement host-side launch parameter calculations using CK host side functions - in hard-coded example * temporary change to generate ELF format binary object file * removed unecessary code, added comments * formatting fix * cleaned up code, added new tests, restructured library: move helper into CK * refactored launch parameter calculation to be more concise * renamed files and variables for more clarity/uniformity * more code cleaning, removed debug statements * moved majority of my files into codegen directory, running properly * updated Embed.cmake(string_view) in codegen directory * updated host directory to match Embed.cmake as well * added old tests in * updated instance generation methods to be more concise * removed layout from launch parameter calculation * working test * fixed issue with verification, all instances working * updated verification in other tests * removed duplicate matrix padder file, removed code dumps * removed old hard-coded tests * removed old host directory, all files in codegen directory now * fixed copyright in files * commenting out validation * renamed files * made changes for review: fixed copyright, renamed files for clarity, removed comments, refactored code * updated headers * removing duplicate file for fwd conv to gemm, merging with original file * fix building codegen with clang++ directly * resolving build error from conv_fwd_to_gemm * fix for previous error * renaming tests * created common test file * cleaned up code, added comments * renamed device op * fixed typos in comments * removed extra space * code cleanup: resolving Amber's comments * removed wrapper struct for matrix padder, fixed template * cleaned up if statements for better readability --------- Co-authored-by: Paul Co-authored-by: Jing Zhang Co-authored-by: M. Amber Hassaan Co-authored-by: illsilin Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- codegen/CMakeLists.txt | 28 +- codegen/driver/main.cpp | 42 +- .../ck/host/device_gemm_multiple_d.hpp | 2 +- .../host/device_gemm_multiple_d/operation.hpp | 17 +- .../host/device_gemm_multiple_d/problem.hpp | 17 +- .../conv_fwd_op.hpp | 60 ++ .../conv_fwd_problem.hpp | 56 ++ codegen/include/ck/host/headers.hpp | 1 - codegen/include/ck/host/operation/gemm.hpp | 2 +- codegen/include/ck/host/stringutils.hpp | 2 +- codegen/include/ck/host/types.hpp | 18 +- codegen/include/ck/host/utils.hpp | 5 +- codegen/src/device_gemm_multiple_d.cpp | 15 +- ...gemm_multiple_d_operation_xdl_cshuffle.cpp | 57 +- .../device_grouped_conv_fwd_multiple_abd.cpp | 42 + ...wd_multiple_abd_operation_xdl_cshuffle.cpp | 364 ++++++++ codegen/src/headers.cpp | 2 +- codegen/src/types.cpp | 8 + codegen/src/utils.cpp | 2 +- codegen/test/CMakeLists.txt | 14 +- codegen/test/common.hpp | 134 +++ codegen/test/gemm_multiple_d.cpp | 7 +- .../test/grouped_conv_fwd_multiple_d_v1.cpp | 209 +++++ .../test/grouped_conv_fwd_multiple_d_v2.cpp | 209 +++++ .../test/grouped_conv_fwd_multiple_d_v3.cpp | 209 +++++ .../test/grouped_conv_fwd_multiple_d_v4.cpp | 209 +++++ codegen/test/rtc/src/compile_kernel.cpp | 8 + codegen/test/rtc/src/hip.cpp | 6 +- .../ck/tensor_operation/gpu/device/helper.hpp | 359 ++++++++ ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 781 ++++++++++++++++++ .../gpu/device/matrix_padder.hpp | 13 + .../transform_conv_fwd_to_gemm.hpp | 564 +++++++++++++ include/ck/utility/array.hpp | 2 + 33 files changed, 3417 insertions(+), 47 deletions(-) create mode 100644 codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp create mode 100644 codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp create mode 100644 codegen/src/device_grouped_conv_fwd_multiple_abd.cpp create mode 100644 codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp create mode 100644 codegen/test/common.hpp create mode 100644 codegen/test/grouped_conv_fwd_multiple_d_v1.cpp create mode 100644 codegen/test/grouped_conv_fwd_multiple_d_v2.cpp create mode 100644 codegen/test/grouped_conv_fwd_multiple_d_v3.cpp create mode 100644 codegen/test/grouped_conv_fwd_multiple_d_v4.cpp create mode 100644 include/ck/tensor_operation/gpu/device/helper.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 72549c9a4e..d8b22fc943 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.16) -project(composable_kernel_host) +project(composable_kernel_host LANGUAGES CXX HIP) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) @@ -12,24 +12,38 @@ find_package(ROCM) include(ROCMInstallTargets) include(ROCMTest) +add_compile_options(-std=c++17) +find_package(hip) +## HIP +set(CMAKE_HIP_PLATFORM amd) +set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER}) +set(CMAKE_HIP_EXTENSIONS ON) +message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}") + +# add include directories +include_directories(BEFORE + ${PROJECT_BINARY_DIR}/include + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/library/include + ${HIP_INCLUDE_DIRS} + ) + list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake) include(Embed) file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS - ${CK_ROOT}/include/ck/*.hpp) + ${CK_ROOT}/include/ck/*.hpp) message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") message(STATUS "RELATIVE: ${CK_ROOT}/include") add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) -add_definitions(-std=c++17) - file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library add_library(ck_host STATIC ${SOURCES}) target_link_libraries(ck_host PRIVATE ck_headers) -set_target_properties(ck_host PROPERTIES - LINKER_LANGUAGE CXX - POSITION_INDEPENDENT_CODE ON) +set_target_properties(ck_host PROPERTIES + LINKER_LANGUAGE CXX + POSITION_INDEPENDENT_CODE ON) target_include_directories(ck_host PUBLIC $ diff --git a/codegen/driver/main.cpp b/codegen/driver/main.cpp index dfd513106b..c7d295de94 100644 --- a/codegen/driver/main.cpp +++ b/codegen/driver/main.cpp @@ -5,24 +5,27 @@ #include #include #include "ck/host/device_gemm_multiple_d/operation.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/stringutils.hpp" using ck::host::Transform; struct Emitters { + // retrieve the hard-coded instances provided, template them, and then store them in a map std::unordered_map()>> m; template - void Register(const std::string& name) + void Register(const std::string& name, const std::string& prologue, const std::string& epilogue) { - m[name] = [] { - auto configs = T::CreateOperations(); + m[name] = [&] { + auto configs = T::CreateOperations(prologue, epilogue); return Transform(configs, [](const auto& ops) { return ToTuple(ops); }); }; } + // takes in an operation instance and uses it to substitute the correct values into the template template static std::string ToTuple(const T& ops) { @@ -31,6 +34,7 @@ struct Emitters return "std::tuple<\n" + ck::host::JoinStrings(templates, ",\n") + ">"; } + // Join together all the strings in the map std::string Emit(const std::string& name) { return ck::host::JoinStrings(m.at(name)(), "\n"); } std::vector List() const @@ -43,9 +47,38 @@ int main(int argc, const char* argv[]) { std::string prog = argv[0]; std::vector args(argv + 1, argv + argc); + + // Specify problem type and problem size + ck::host::device_gemm_multiple_d::Problem prob; + prob.M = 1024; + prob.N = 1024; + prob.K = 1024; + + // user provided fusion + std::string prologue = ""; + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};)"; + + // Load in operations into the Register Emitters e; e.Register( - "DeviceGemmMultipleD_Xdl_CShuffle"); + "DeviceGemmMultipleD_Xdl_CShuffle", prologue, epilogue); if(args.empty() or std::any_of(args.begin(), args.end(), [](auto arg) { return arg == "-h" or arg == "--help"; @@ -64,6 +97,7 @@ int main(int argc, const char* argv[]) return 0; } + // print out all the instances for the operation that was chosen at the command line for(auto name : args) std::cout << e.Emit(name) << std::endl; diff --git a/codegen/include/ck/host/device_gemm_multiple_d.hpp b/codegen/include/ck/host/device_gemm_multiple_d.hpp index 88e040db53..02c19c88e7 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp index f9d39633ac..359da7d8cf 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp @@ -14,10 +14,15 @@ namespace ck { namespace host { namespace device_gemm_multiple_d { +// defines all values need for an instance of fwd conv struct Operation_Xdl_CShuffle { - static std::vector> CreateOperations(); - static std::vector CreateOperations(const Problem& prob); + // returns a vector of instances, only given fusion operators: will use default problem spec + static std::vector> + CreateOperations(const std::string& prologue, const std::string& epilogue); + // returns a vector of instances, given a problem spec and fusion operators + static std::vector + CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue); TensorDesc A{}; TensorDesc B{}; DataType acc = DataType::Float; @@ -27,13 +32,21 @@ struct Operation_Xdl_CShuffle std::string a_elem_op = PassThrough; std::string b_elem_op = PassThrough; std::string cde_elem_op = Bilinear; + std::string prologue = ""; + std::string epilogue = ""; std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default"; + // tuning parameters operation::TileDesc tile_desc{}; operation::BlockTransferDesc a_block_transfer{}; operation::BlockTransferDesc b_block_transfer{}; operation::CShuffleDesc cshuffle{}; operation::CBlockTransferDesc c_block_transfer{}; + // functions to update fusion operators if provided + void update_prologue(const std::string& prologue); + void update_epilogue(const std::string& epilogue); + /**constexpr**/ bool IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_); + // returns a templated instance Solution ToSolution() const; }; diff --git a/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp index f6dbc2b6e8..f4036328ec 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -12,11 +12,14 @@ namespace ck { namespace host { namespace device_gemm_multiple_d { +// defines the problem specification for a GEMM operation struct Problem { - std::size_t M = 0; - std::size_t N = 0; - std::size_t K = 0; + // dimensions for GEMM operation + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + // layouts for tensors bool TransA = false; bool TransB = false; bool TransE = false; @@ -29,9 +32,13 @@ struct Problem std::string BElementOp = PassThrough; std::string CDEElementOp = PassThrough; + // returns the correct device op file for the operation std::string GetIncludeHeader() const; - std::vector GetSolutions(const std::string& arch) const; + // returns a list of instances based on the problem spec and provided fusion operations + std::vector GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const; }; } // namespace device_gemm_multiple_d diff --git a/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp new file mode 100644 index 0000000000..5ad1dce176 --- /dev/null +++ b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" +#include "ck/host/operation/gemm.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" + +namespace ck { +namespace host { +namespace conv { + +// defines the values needed for an instance of forward convolution and functions to return +// (templated) instances +struct Operation_Conv_Fwd_Xdl_Cshuffle +{ + // returns a vector of instances given the fusion operations, uses default values for problem + // spec + static std::vector + CreateOperations(const std::string& prologue, const std::string& epilogue); + // returns a vector of instances, provided with a problem spec and fusion operations + static std::vector CreateOperations( + const Problem_Conv_Fwd& prob, const std::string& prologue, const std::string& epilogue); + std::size_t NumDim; + TensorDesc A{}; + TensorDesc B{}; + DataType acc = DataType::Float; + DataType cs_type = DataType::Half; + std::vector Ds = {}; + TensorDesc E{}; + std::string a_elem_op = PassThrough; + std::string b_elem_op = PassThrough; + std::string cde_elem_op = PassThrough; + std::string prologue = ""; + std::string epilogue = ""; + std::string conv_specialization = + "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default"; + std::string gemm_specialization = + "ck::tensor_operation::device::GemmSpecialization::MNKPadding"; + // tuning parameters + operation::TileDesc tile_desc{}; + operation::BlockTransferDesc a_block_transfer{}; + operation::BlockTransferDesc b_block_transfer{}; + operation::CShuffleDesc cshuffle{}; + operation::CBlockTransferDesc c_block_transfer{}; + + // functions to update fusion operations if they are provided + void update_prologue(const std::string& prologue); + void update_epilogue(const std::string& epilogue); + // returns a templated instance + Solution ToSolution() const; +}; + +} // namespace conv +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp new file mode 100644 index 0000000000..433f9a8fc9 --- /dev/null +++ b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace conv { + +// defines the problem specification for a forward convolution operation +struct Problem_Conv_Fwd +{ + std::size_t NumDim = 0; + // size of a forward convolution operation + std::size_t G = 0; + std::size_t N = 0; + std::size_t C = 0; + std::size_t Hi = 0; + std::size_t Wi = 0; + std::size_t Ho = 0; + std::size_t Wo = 0; + std::size_t K = 0; + std::size_t Y = 0; + std::size_t X = 0; + Layout ALayout = Layout::NHWGC; + Layout BLayout = Layout::GKYXC; + Layout ELayout = Layout::NHWGK; + std::vector DsLayout = {}; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType EDataType = DataType::Half; + std::vector DsDataType = {}; + std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string CDEElementOp = "ck::tensor_operation::element_wise::PassThrough"; + + // returns the correct device op file for the operation + std::string GetIncludeHeader() const; + + // returns a list of instances based on the problem spec and provided fusion operations + std::vector GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const; +}; + +} // namespace conv +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/headers.hpp b/codegen/include/ck/host/headers.hpp index 3da05baaaf..54f8d9f731 100644 --- a/codegen/include/ck/host/headers.hpp +++ b/codegen/include/ck/host/headers.hpp @@ -4,7 +4,6 @@ #pragma once #include -#include #include #include #include diff --git a/codegen/include/ck/host/operation/gemm.hpp b/codegen/include/ck/host/operation/gemm.hpp index f587122b05..84ef92f0a0 100644 --- a/codegen/include/ck/host/operation/gemm.hpp +++ b/codegen/include/ck/host/operation/gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/codegen/include/ck/host/stringutils.hpp b/codegen/include/ck/host/stringutils.hpp index 01374b86c8..89c1884d2e 100644 --- a/codegen/include/ck/host/stringutils.hpp +++ b/codegen/include/ck/host/stringutils.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/codegen/include/ck/host/types.hpp b/codegen/include/ck/host/types.hpp index 23488a66d0..812c073678 100644 --- a/codegen/include/ck/host/types.hpp +++ b/codegen/include/ck/host/types.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -12,6 +12,7 @@ namespace ck { namespace host { +// holds the templated instance, substitues values into template from instancess struct Solution { @@ -33,6 +34,7 @@ struct Solution std::unordered_map template_values; }; +// supported data types enum class DataType { Half, @@ -40,22 +42,28 @@ enum class DataType Int8, Int32 }; - std::string ToString(DataType dt); +// supported layouts: gemm and fwd conv enum class Layout { Row, - Column + Column, + GKYXC, + GKCYX, + GNHWK, + GNHWC, + NHWGC, + NHWGK }; - std::string ToString(Layout dl); +Layout ToLayout(bool Trans); // returns the layout for gemm +// supported GEMM types enum class GemmType { Default }; - std::string ToString(GemmType gt); struct TensorDesc diff --git a/codegen/include/ck/host/utils.hpp b/codegen/include/ck/host/utils.hpp index e8785a456f..21926814f1 100644 --- a/codegen/include/ck/host/utils.hpp +++ b/codegen/include/ck/host/utils.hpp @@ -1,10 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include +#include +#include namespace ck { namespace host { @@ -12,6 +14,5 @@ namespace host { std::size_t integer_divide_ceil(std::size_t x, std::size_t y); const std::unordered_set& get_xdlop_archs(); - } // namespace host } // namespace ck diff --git a/codegen/src/device_gemm_multiple_d.cpp b/codegen/src/device_gemm_multiple_d.cpp index ec25afc0f9..44bc051a8b 100644 --- a/codegen/src/device_gemm_multiple_d.cpp +++ b/codegen/src/device_gemm_multiple_d.cpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/host/device_gemm_multiple_d/problem.hpp" #include "ck/host/device_gemm_multiple_d/operation.hpp" @@ -11,23 +11,28 @@ namespace ck { namespace host { namespace device_gemm_multiple_d { +// return the relevant device op file based on the operation std::string Problem::GetIncludeHeader() const { return "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"; } -std::vector Problem::GetSolutions(const std::string& arch) const +// returns templated instances when provided with a problem specification +std::vector Problem::GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const { if(get_xdlop_archs().count(arch) == 0) return {}; - auto ops = ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle::CreateOperations(*this); + auto ops = ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle::CreateOperations( + *this, prologue, epilogue); // obtains vector of instances std::vector result; std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { - return op.ToSolution(); + return op.ToSolution(); // template instance with correct values }); return result; } } // namespace device_gemm_multiple_d } // namespace host -} // namespace ck \ No newline at end of file +} // namespace ck diff --git a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp index 9e397497ee..a2e8eccbf1 100644 --- a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp @@ -10,6 +10,7 @@ namespace ck { namespace host { namespace device_gemm_multiple_d { +// calculate appropriate Gemm Specification based on input tensor dimensions static std::string GetGemmSpec(const std::size_t m, const std::size_t n, const std::size_t k, @@ -30,9 +31,40 @@ static std::string GetGemmSpec(const std::size_t m, return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; } +// function to update prologue/epilogue with user provided operation +void Operation_Xdl_CShuffle::update_prologue(const std::string& prologue) +{ + if(!prologue.empty()) + { + this->prologue = prologue; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->prologue = ""; + } +} + +void Operation_Xdl_CShuffle::update_epilogue(const std::string& epilogue) +{ + if(!epilogue.empty()) + { + this->epilogue = epilogue; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->epilogue = ""; + } +} + +// accounts for all possible combinations of Row/Col major static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } -std::vector Operation_Xdl_CShuffle::CreateOperations(const Problem& prob) +// Hard-code tuning parameters in modularized fashion, string them together into a vector of +// instances +std::vector Operation_Xdl_CShuffle::CreateOperations( + const Problem& prob, const std::string& prologue, const std::string& epilogue) { std::vector result; @@ -155,6 +187,7 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(con // clang-format on }; + // choose correct arrangement of tuning parameters based on the layout of each tensor const auto a_block_descriptions = prob.TransA ? a_block_descriptions_colmajor : a_block_descriptions_rowmajor; const auto b_block_descriptions = @@ -165,6 +198,7 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(con assert(tile_descriptions.size() == cshuffle_descriptions.size()); assert(tile_descriptions.size() == c_block_descriptions.size()); + // Put all values together into a single operation > store into the result vector for(std::size_t i = 0; i < tile_descriptions.size(); i++) { Operation_Xdl_CShuffle x; @@ -188,12 +222,17 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(con x.tile_desc.m_per_block, x.tile_desc.n_per_block, x.tile_desc.k_per_block); + x.update_prologue(prologue); + x.update_epilogue(epilogue); result.push_back(x); } return result; } -std::vector> Operation_Xdl_CShuffle::CreateOperations() +// set up instances when not provided with a problem specification, use default operation values and +// all possible layout combinations +std::vector> +Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue) { std::vector problems; for(bool TransA : {true, false}) @@ -204,7 +243,8 @@ std::vector> Operation_Xdl_CShuffle::CreateO prob.TransB = TransB; problems.push_back(prob); } - return Transform(problems, [](const Problem& p) { return CreateOperations(p); }); + return Transform(problems, + [&](const Problem& p) { return CreateOperations(p, prologue, epilogue); }); } static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = @@ -224,9 +264,20 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = "${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, " "${CDEBlockTransferScalarPerVector_NPerBlock}>"; +// use hardcoded instances from vector of operations to substitute values into instance template Solution Operation_Xdl_CShuffle::ToSolution() const { std::unordered_map values = { + {"name", + std::to_string(this->tile_desc.block_size) + "_" + + std::to_string(this->tile_desc.m_per_block) + "_" + + std::to_string(this->tile_desc.n_per_block) + "_" + + std::to_string(this->tile_desc.k_per_block) + "_" + + std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" + + std::to_string(this->tile_desc.m_per_XDL) + "_" + + std::to_string(this->tile_desc.n_per_XDL) + "_" + + std::to_string(this->tile_desc.m_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.n_Xdl_per_wave)}, {"LayoutA", ToString(this->A.layout)}, {"LayoutB", ToString(this->B.layout)}, {"LayoutDs", diff --git a/codegen/src/device_grouped_conv_fwd_multiple_abd.cpp b/codegen/src/device_grouped_conv_fwd_multiple_abd.cpp new file mode 100644 index 0000000000..c689e5ec95 --- /dev/null +++ b/codegen/src/device_grouped_conv_fwd_multiple_abd.cpp @@ -0,0 +1,42 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/utils.hpp" +#include +#include + +namespace ck { +namespace host { +namespace conv { + +// return the relevant device op file based on the operation +// NOTE: this is a modified version of the original CK file that calls the kernel from a device +// function and makes the Argument class accessible on the device +std::string Problem_Conv_Fwd::GetIncludeHeader() const +{ + return "ck/tensor_operation/gpu/device/impl/" + "codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"; +} + +// return vector of forward convolution instances when provided with a problem instance +std::vector Problem_Conv_Fwd::GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const +{ + if(get_xdlop_archs().count(arch) == 0) + return {}; + auto ops = ck::host::conv::Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations( + *this, prologue, epilogue); + std::vector result; + std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { + return op.ToSolution(); + }); + return result; +} + +} // namespace conv +} // namespace host +} // namespace ck diff --git a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp new file mode 100644 index 0000000000..94161a76d9 --- /dev/null +++ b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp @@ -0,0 +1,364 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace conv { + +// calculate appropriate Gemm Specification based on input tensor dimensions +// NOTE: in CK, MNKPadding is always used for forward convolution +static std::string GetGemmSpec(const std::size_t m, + const std::size_t n, + const std::size_t k, + const std::size_t m_per_block, + const std::size_t n_per_block, + const std::size_t k_per_block) +{ + std::string spec = ""; + if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0) + spec += "M"; + if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0) + spec += "N"; + if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0) + spec += "K"; + if(spec == "") + return "ck::tensor_operation::device::GemmSpecialization::Default"; + + return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; +} + +// function to update prologue/epilogue with user provided operation +void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& prologue) +{ + if(!prologue.empty()) + { + this->prologue = prologue; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->prologue = ""; + } +} + +void Operation_Conv_Fwd_Xdl_Cshuffle::update_epilogue(const std::string& epilogue) +{ + if(!epilogue.empty()) + { + this->epilogue = epilogue; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->epilogue = ""; + } +} + +// Hard-code tuning parameters in modularized fashion, string them together into a vector of +// instances +std::vector Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations( + const Problem_Conv_Fwd& prob, const std::string& prologue, const std::string& epilogue) +{ + std::vector result; + + std::vector tile_descriptions = { + // clang-format off +// Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| NumGemmK| +// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch| +// | | | | | | | | Wave| Wave| Stage| +// | | | | | | | | | | | + { 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, 1}, + { 256, 128, 256, 32, 8, 8, 32, 32, 4, 2, 1}, + { 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1}, + { 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, 1}, + { 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1}, + { 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1} + // clang-format on + }; + + std::vector a_block_descriptions = { + // clang-format off +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| +// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1} + // clang-format on + }; + + std::vector b_block_descriptions = { + // clang-format off +// BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| +// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1} + // clang-format on + }; + + std::vector cshuffle_descriptions = { + // clang-format off +// CShuffle| CShuffle| +// MXdlPerWave| NXdlPerWave| +// PerShuffle| PerShuffle| +// | | + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1} + // clang-format on + }; + + std::vector c_block_descriptions = { + // clang-format off +// CBlockTransferClusterLengths| CBlockTransfer +// _MBlock_MWaveMPerXdl| ScalarPerVector +// _NBlock_NWaveNPerXdl| _NWaveNPerXdl +// | + { S<1, 16, 1, 4>, 1}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1, 4>, 1}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1, 8>, 8} + // clang-format on + }; + + assert(tile_descriptions.size() == a_block_descriptions.size()); + assert(tile_descriptions.size() == b_block_descriptions.size()); + assert(tile_descriptions.size() == cshuffle_descriptions.size()); + assert(tile_descriptions.size() == c_block_descriptions.size()); + + // Put all values together into a single operation > store into the result vector + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Conv_Fwd_Xdl_Cshuffle x; + x.NumDim = prob.NumDim; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b_block_transfer = b_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, prob.ALayout}; + x.B = TensorDesc{prob.BDataType, prob.BLayout}; + x.E = TensorDesc{prob.EDataType, prob.ELayout}; + x.Ds = Transform(prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { + return TensorDesc{dt, lo}; + }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; + x.update_prologue(prologue); + x.update_epilogue(epilogue); + result.push_back(x); + } + return result; +} + +// set up instances when not provided with a problem specification, use default operation values +std::vector +Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations(const std::string& prologue, + const std::string& epilogue) +{ + Problem_Conv_Fwd prob; + return CreateOperations(prob, prologue, epilogue); +} + +static const char* const CopyDevice_ConvTemplate = + R"( +${Prologue} +${Epilogue} + +using CDEElementOp = Epilogue; +using DeviceConv = ck::tensor_operation::device::CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<${NumDim}, ${LayoutA}, ${LayoutB}, ${LayoutDs}, ${LayoutE}, ${ADataType}, ${BDataType}, ${AccDataType}, ${CShuffleDataType}, ${DsDataType}, ${EDataType}, ${AElementwiseOperation}, ${BElementwiseOperation}, ${CDEElementwiseOperation}, ${ConvSpecialization}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, ${MPerBlock}, ${NPerBlock}, ${KPerBlock}, ${AK1}, ${BK1}, ${MPerXDL}, ${NPerXDL}, ${MXdlPerWave}, ${NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, ${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, ${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, ${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, ${BBlockTransferThreadClusterLengths_BK0_N_BK1}, ${BBlockTransferThreadClusterArrangeOrder}, ${BBlockTransferSrcAccessOrder}, ${BBlockTransferSrcVectorDim}, ${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, ${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, ${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, ${CDEBlockTransferScalarPerVector_NPerBlock}>; + +constexpr ck::index_t NumATensor = ck::tensor_operation::device::GetNumABTensors(); +constexpr ck::index_t NumBTensor = ck::tensor_operation::device::GetNumABTensors(); + +extern "C" __global__ void run_${name}( + const ${ADataType}* in_dev, + const ${BDataType}* wei_dev, + ${EDataType}* __restrict__ out_dev, + ck::Array in_lengths, + ck::Array in_strides, + ck::Array wei_lengths, + ck::Array wei_strides, + ck::Array out_lengths, + ck::Array out_strides, + ck::Array conv_filter_strides, + ck::Array conv_filter_dilations, + ck::Array input_left_pads, + ck::Array input_right_pads, + const ${AElementwiseOperation} a_element_op, + const ${BElementwiseOperation} b_element_op, + const ${CDEElementwiseOperation} cde_element_op +){ + + + auto arg = DeviceConv::Argument(in_dev, + wei_dev, + ck::Array{}, + out_dev, + in_lengths, + in_strides, + wei_lengths, + wei_strides, + ck::Array, 0>{}, + ck::Array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + ${AElementwiseOperation}{}, + ${BElementwiseOperation}{}, + ${CDEElementwiseOperation}{1.0f, 1.0f}); + + constexpr ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler(); + + // GridwiseGemm + using GridwiseGemm = DeviceConv::GridwiseGemm; + + static constexpr auto I0 = ck::Number<0>{}; + + ck::tensor_operation::device::device_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + const ${ADataType}*, + const ${BDataType}*, + typename GridwiseGemm::DsGridPointer, + ${EDataType}, + ${AElementwiseOperation}, + ${BElementwiseOperation}, + ${CDEElementwiseOperation}, + DeviceConv::AGridDesc_AK0_M_AK1, + DeviceConv::BGridDesc_BK0_N_BK1, + DeviceConv::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceConv::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceConv::Block2ETileMap, + ck::tensor_operation::device::ComputePtrOffsetOfStridedBatch, + ck::integral_constant{}, + false, + false> + ( + arg.p_as_grid_.At(I0), + arg.p_bs_grid_.At(I0), + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_c_wis_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_batch_ + ); + +} +)"; + +// use hardcoded instances from vector of operations to substitute values into instance template +Solution Operation_Conv_Fwd_Xdl_Cshuffle::ToSolution() const +{ + std::unordered_map values = { + {"name", + std::to_string(this->tile_desc.block_size) + "_" + + std::to_string(this->tile_desc.m_per_block) + "_" + + std::to_string(this->tile_desc.n_per_block) + "_" + + std::to_string(this->tile_desc.k_per_block) + "_" + + std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" + + std::to_string(this->tile_desc.m_per_XDL) + "_" + + std::to_string(this->tile_desc.n_per_XDL) + "_" + + std::to_string(this->tile_desc.m_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.n_Xdl_per_wave)}, + {"NumDim", std::to_string(this->NumDim)}, + {"LayoutA", ToString(this->A.layout)}, + {"LayoutB", ToString(this->B.layout)}, + {"LayoutDs", + MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.layout); }))}, + {"LayoutE", ToString(this->E.layout)}, + {"ADataType", ToString(this->A.element)}, + {"BDataType", ToString(this->B.element)}, + {"AccDataType", ToString(this->acc)}, + {"ComputeDataType", ToString(this->A.element)}, + {"CShuffleDataType", ToString(this->cs_type)}, + {"DsDataType", + MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.element); }))}, + {"EDataType", ToString(this->E.element)}, + {"AElementwiseOperation", this->a_elem_op}, + {"BElementwiseOperation", this->b_elem_op}, + {"CDEElementwiseOperation", this->cde_elem_op}, + {"Prologue", this->prologue}, + {"Epilogue", this->epilogue}, + {"ConvSpecialization", this->conv_specialization}, + {"GemmSpecialization", this->gemm_specialization}, + {"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)}, + {"BlockSize", std::to_string(this->tile_desc.block_size)}, + {"MPerBlock", std::to_string(this->tile_desc.m_per_block)}, + {"NPerBlock", std::to_string(this->tile_desc.n_per_block)}, + {"KPerBlock", std::to_string(this->tile_desc.k_per_block)}, + {"AK1", std::to_string(this->tile_desc.ak1)}, + {"BK1", std::to_string(this->tile_desc.bk1)}, + {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)}, + {"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)}, + {"MXdlPerWave", std::to_string(this->tile_desc.m_Xdl_per_wave)}, + {"NXdlPerWave", std::to_string(this->tile_desc.n_Xdl_per_wave)}, + {"ABlockTransferThreadClusterLengths_AK0_M_AK1", + this->a_block_transfer.thread_cluster_length}, + {"ABlockTransferThreadClusterArrangeOrder", + this->a_block_transfer.thread_cluster_arrange_order}, + {"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order}, + {"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)}, + {"ABlockTransferSrcScalarPerVector", + std::to_string(this->a_block_transfer.src_scalar_per_vector)}, + {"ABlockTransferDstScalarPerVector_AK1", + std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)}, + {"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)}, + {"BBlockTransferThreadClusterLengths_BK0_N_BK1", + this->b_block_transfer.thread_cluster_length}, + {"BBlockTransferThreadClusterArrangeOrder", + this->b_block_transfer.thread_cluster_arrange_order}, + {"BBlockTransferSrcAccessOrder", this->b_block_transfer.src_access_order}, + {"BBlockTransferSrcVectorDim", std::to_string(this->b_block_transfer.src_vec_dim)}, + {"BBlockTransferSrcScalarPerVector", + std::to_string(this->b_block_transfer.src_scalar_per_vector)}, + {"BBlockTransferDstScalarPerVector_BK1", + std::to_string(this->b_block_transfer.dst_scalar_per_vector_k1)}, + {"BBlockLdsExtraN", std::to_string(this->b_block_transfer.lds_add_extra_dim)}, + {"CShuffleMXdlPerWavePerShuffle", + std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)}, + {"CShuffleNXdlPerWavePerShuffle", + std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)}, + {"CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock", + this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, + {"CDEBlockTransferScalarPerVector_NPerBlock", + std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + }; + + return Solution{InterpolateString(CopyDevice_ConvTemplate, values), std::move(values)}; +} + +} // namespace conv +} // namespace host +} // namespace ck diff --git a/codegen/src/headers.cpp b/codegen/src/headers.cpp index 6fcb94cdbd..f685aca044 100644 --- a/codegen/src/headers.cpp +++ b/codegen/src/headers.cpp @@ -14,4 +14,4 @@ std::unordered_map GetHeaders() } } // namespace host -} // namespace ck \ No newline at end of file +} // namespace ck diff --git a/codegen/src/types.cpp b/codegen/src/types.cpp index d43df73f33..a8a8b10c04 100644 --- a/codegen/src/types.cpp +++ b/codegen/src/types.cpp @@ -29,12 +29,20 @@ std::string ToString(DataType dt) throw std::runtime_error("Incorrect data type"); } +Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } + std::string ToString(Layout dl) { switch(dl) { case Layout::Row: return "ck::tensor_layout::gemm::RowMajor"; case Layout::Column: return "ck::tensor_layout::gemm::ColumnMajor"; + case Layout::GKCYX: return "ck::tensor_layout::convolution::GKCYX"; + case Layout::GKYXC: return "ck::tensor_layout::convolution::GKYXC"; + case Layout::GNHWK: return "ck::tensor_layout::convolution::GNHWK"; + case Layout::GNHWC: return "ck::tensor_layout::convolution::GNHWC"; + case Layout::NHWGC: return "ck::tensor_layout::convolution::NHWGC"; + case Layout::NHWGK: return "ck::tensor_layout::convolution::NHWGK"; } throw std::runtime_error("Incorrect layout"); } diff --git a/codegen/src/utils.cpp b/codegen/src/utils.cpp index cd6700c489..19627d4cf6 100644 --- a/codegen/src/utils.cpp +++ b/codegen/src/utils.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/host/utils.hpp" diff --git a/codegen/test/CMakeLists.txt b/codegen/test/CMakeLists.txt index 897cce1c94..f891286019 100644 --- a/codegen/test/CMakeLists.txt +++ b/codegen/test/CMakeLists.txt @@ -1,11 +1,13 @@ - list(APPEND CMAKE_PREFIX_PATH /opt/rocm) add_subdirectory(rtc) - file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp) foreach(TEST_SRC ${TEST_SRCS}) -get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) -rocm_add_test_executable(test_host_${BASE_NAME} ${TEST_SRC}) -target_link_libraries(test_host_${BASE_NAME} ck_rtc ck_host) -target_include_directories(test_host_${BASE_NAME} PUBLIC include()) + set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP) + get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) + rocm_add_test_executable(test_host_${BASE_NAME} ${TEST_SRC}) + target_link_libraries(test_host_${BASE_NAME} ck_rtc ck_host) + # target_link_libraries(test_host_${BASE_NAME} ${CK_ROOT}/build/lib/libutility.a) + target_include_directories(test_host_${BASE_NAME} PUBLIC include()) + target_include_directories(test_host_${BASE_NAME} PUBLIC ${CK_ROOT}/include) + target_include_directories(test_host_${BASE_NAME} PUBLIC ${CK_ROOT}/library/include) endforeach() diff --git a/codegen/test/common.hpp b/codegen/test/common.hpp new file mode 100644 index 0000000000..99d4c64973 --- /dev/null +++ b/codegen/test/common.hpp @@ -0,0 +1,134 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +std::vector get_headers_for_test() +{ + std::vector result; + auto hs = ck::host::GetHeaders(); + std::transform( + hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file { + return {p.first, p.second}; + }); + return result; +} + +template +std::size_t GetSize(V mLens, V mStrides) +{ + std::size_t space = 1; + for(std::size_t i = 0; i < mLens.Size(); ++i) + { + if(mLens[i] == 0) + continue; + + space += (mLens[i] - 1) * mStrides[i]; + } + return space; +} + +template +rtc::buffer generate_buffer(V mLens, V mStrides, std::size_t seed = 0) +{ + std::size_t space = GetSize(mLens, mStrides); + rtc::buffer result(space); + std::mt19937 gen(seed); + std::uniform_real_distribution dis(-1.0); + std::generate(result.begin(), result.end(), [&] { return dis(gen); }); + // std::fill(result.begin(), result.end(), 1); + return result; +} + +template +bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) +{ + return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) { + return fabs(x - y) < atol + rtol * fabs(y); + }); +} + +std::string classify(double x) +{ + switch(std::fpclassify(x)) + { + case FP_INFINITE: return "inf"; + case FP_NAN: return "nan"; + case FP_NORMAL: return "normal"; + case FP_SUBNORMAL: return "subnormal"; + case FP_ZERO: return "zero"; + default: return "unknown"; + } +} + +template +void print_classification(const Buffer& x) +{ + std::unordered_set result; + for(const auto& i : x) + result.insert(classify(i)); + for(const auto& c : result) + std::cout << c << ", "; + std::cout << std::endl; +} + +template +void print_statistics(const Buffer& x) +{ + std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", "; + std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", "; + double num_elements = x.size(); + auto mean = + std::accumulate(x.begin(), x.end(), double{0.0}, std::plus{}) / num_elements; + auto stddev = std::sqrt( + std::accumulate(x.begin(), + x.end(), + double{0.0}, + [&](double r, double v) { return r + std::pow((v - mean), 2.0); }) / + num_elements); + std::cout << "Mean: " << mean << ", "; + std::cout << "StdDev: " << stddev << "\n"; +} + +template +void print_preview(const Buffer& x) +{ + if(x.size() <= 10) + { + std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i << ", "; }); + } + else + { + std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; }); + std::cout << "..., "; + std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; }); + } + std::cout << std::endl; +} + +template +struct check_all +{ + rtc::buffer data{}; + bool operator()(const rtc::buffer& x) + { + if(data.empty()) + { + data = x; + return true; + } + return allclose(data, x); + } +}; + +template +auto report(const Solution& solution, bool pass) +{ + return test::make_predicate(solution.ToTemplateString(), [=] { return pass; }); +} diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index 17b659993a..bd7ef463fb 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -10,6 +10,7 @@ #include #include #include +#include using half = _Float16; // using half = __fp16; @@ -159,7 +160,10 @@ TEST_CASE(test_problem_kernel) auto b = to_gpu(generate_buffer(1024 * 1024, 1)); auto c = to_gpu(generate_buffer(1024 * 1024, 2)); - for(auto solution : prob.GetSolutions("gfx90a")) + std::string epilogue = ""; + std::string prologue = ""; + + for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue)) { auto src = ck::host::InterpolateString(gemm_compile_check, {{"include", prob.GetIncludeHeader()}, @@ -178,6 +182,7 @@ TEST_CASE(test_problem_kernel) auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) * ck::host::integer_divide_ceil(prob.N, n_per_block); k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data()); + CHECK(report(solution, check(rtc::from_gpu(c)))); } } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp new file mode 100644 index 0000000000..3c477692e5 --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp @@ -0,0 +1,209 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include +#include +#include +#include "common.hpp" +#include + +// Need this for verification +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + ck::Array d_lengths = {}; + + ck::Array in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + ck::Array d_strides = {}; + + ck::Array conv_filter_strides = {2, 2}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {1, 1}; + ck::Array input_right_pads = {1, 1}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {2, 2}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {1, 1}; + std::vector input_right_pads_ = {1, 1}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp new file mode 100644 index 0000000000..ec9bd2b781 --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp @@ -0,0 +1,209 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include +#include +#include +#include + +// need this for validation +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + ck::Array d_lengths = {}; + + ck::Array in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + ck::Array d_strides = {}; + + ck::Array conv_filter_strides = {1, 1}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {0, 0}; + ck::Array input_right_pads = {0, 0}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {1, 1}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {0, 0}; + std::vector input_right_pads_ = {0, 0}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp new file mode 100644 index 0000000000..9850184c5e --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp @@ -0,0 +1,209 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "common.hpp" +#include +#include +#include +#include + +// need this for verification +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + ck::Array d_lengths = {}; + + ck::Array in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + ck::Array d_strides = {}; + + ck::Array conv_filter_strides = {2, 2}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {0, 0}; + ck::Array input_right_pads = {0, 0}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {2, 2}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {0, 0}; + std::vector input_right_pads_ = {0, 0}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp new file mode 100644 index 0000000000..907f744db4 --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp @@ -0,0 +1,209 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "common.hpp" +#include +#include +#include +#include + +// need this for verification +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + ck::Array d_lengths = {}; + + ck::Array in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + ck::Array d_strides = {}; + + ck::Array conv_filter_strides = {1, 1}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {1, 1}; + ck::Array input_right_pads = {1, 1}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {1, 1}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {1, 1}; + std::vector input_right_pads_ = {1, 1}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index 7ea55b9328..d84ebf4de9 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -56,6 +56,8 @@ void write_string(const std::string& filename, const std::string_view& buffer) } std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device-only"; } +// TODO: undo after extracting the codeobj +// std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip"; } kernel compile_kernel(const std::vector& srcs, compile_options options) { @@ -89,6 +91,12 @@ kernel compile_kernel(const std::vector& srcs, compile_options options auto obj = read_buffer(out_path.string()); + std::ofstream ofh("obj.o", std::ios::binary); + for(auto i : obj) + ofh << i; + ofh.close(); + // int s = std::system(("/usr/bin/cp " + out_path.string() + " codeobj.bin").c_str()); + // assert(s == 0); return kernel{obj.data(), options.kernel_name}; } diff --git a/codegen/test/rtc/src/hip.cpp b/codegen/test/rtc/src/hip.cpp index 10e38c9adb..747f83e3ba 100644 --- a/codegen/test/rtc/src/hip.cpp +++ b/codegen/test/rtc/src/hip.cpp @@ -2,6 +2,7 @@ #include #include #include +#include namespace rtc { @@ -49,7 +50,10 @@ std::size_t get_available_gpu_memory() size_t total; auto status = hipMemGetInfo(&free, &total); if(status != hipSuccess) - throw std::runtime_error("Failed getting available memory: " + hip_error(status)); + { + std::cerr << "Failed getting available memory: " + hip_error(status) << std::endl; + return (8ull * 1024ull * 1024ull * 1024ull); + } return free; } diff --git a/include/ck/tensor_operation/gpu/device/helper.hpp b/include/ck/tensor_operation/gpu/device/helper.hpp new file mode 100644 index 0000000000..c52566509f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/helper.hpp @@ -0,0 +1,359 @@ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include +#include + +// functions to return the corresponding structs based on generated template parameters + +using layouts = std::variant; +// return the layout type: currently this is the only type supported in MIOpen +auto layout_type(std::string type) +{ + if(type == "ck::tensor_layout::convolution::NHWGK") + { + return ck::tensor_layout::convolution::NHWGK{}; + } + throw std::runtime_error("Incorrect layout"); +} +// return the right gemm spec based on the generated template parameters +ck::tensor_operation::device::GemmSpecialization gemm_type(std::string type) +{ + if(type == "ck::tensor_operation::device::GemmSpecialization::Default") + { + return ck::tensor_operation::device::GemmSpecialization::Default; + } + if(type == "ck::tensor_operation::device::GemmSpecialization::MNKPadding") + { + return ck::tensor_operation::device::GemmSpecialization::MNKPadding; + } + throw std::runtime_error("Incorrect gemm spec: " + type); +} + +// return the type of convolution +ck::tensor_operation::device::ConvolutionForwardSpecialization conv_type(std::string type) +{ + if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + } + if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + } + if(type == + "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + } + if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + } + throw std::runtime_error("Incorrect conv spec: " + type); +} + +// Function to call on MatrixPadder via a wrapper struct +// NOTE: CK only uses MNKPadding for forward convolution +template +auto pad(ck::index_t mpb, + ck::index_t npb, + ck::index_t kpb, + ck::tensor_operation::device::GemmSpecialization gemm, + CDesc_MRaw_NRaw conv) +{ + if(gemm == ck::tensor_operation::device::GemmSpecialization::MNKPadding) + { + ck::tensor_operation::device::MatrixPadder< + ck::tensor_operation::device::GemmSpecialization::MNKPadding, + ck::index_t, + ck::index_t, + ck::index_t> + a; + a.MPerTile_ = mpb; + a.NPerTile_ = npb; + a.KPerTile_ = kpb; + auto tmp = grid_desc(a, conv); + return tmp; + } + throw std::runtime_error("Incorrect template parameters, check gemm spec"); +} + +// Functions to call on TransformConvFwdToGemm through wrapper: different functions based on num +// dims +// FIXME: add a way to properly pass in the layout +auto transform_conv(ck::index_t num_dim, + ck::tensor_operation::device::ConvolutionForwardSpecialization spec, + ck::Array out_lengths, + ck::Array out_strides) +{ + if(num_dim == 2 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 2 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 2 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + throw std::runtime_error("Incorrect conv spec"); +} + +auto transform_conv_3d(ck::index_t num_dim, + ck::tensor_operation::device::ConvolutionForwardSpecialization spec, + ck::Array out_lengths, + ck::Array out_strides) +{ + if(num_dim == 3 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 3 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 3 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + throw std::runtime_error("Incorrect conv spec"); +} + +auto transform_conv_1d(ck::index_t num_dim, + ck::tensor_operation::device::ConvolutionForwardSpecialization spec, + ck::Array out_lengths, + ck::Array out_strides) +{ + if(num_dim == 1 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 1 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 1 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + throw std::runtime_error("Incorrect dims or conv spec"); +} + +template +auto block_2_etile(ck::index_t m_per_block, ck::index_t n_per_block, CGridDesc_M_N matrix_padder) +{ + if(m_per_block == 32 && n_per_block == 64) + { + auto b2e = ck::BlockToCTileMap_M00_N0_M01Adapt<32, 64, CGridDesc_M_N>(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 32 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<32, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 64 && n_per_block == 32) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<64, 32, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 64 && n_per_block == 64) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<64, 64, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 64 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<64, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 32) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 32, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 64) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 64, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 256) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 256, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 256 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<256, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + throw std::runtime_error("Incorrect template parameters"); +} + +// wrapper functions by dims to get grid size - uses above 3 functions +// TODO: eventually remove the 1d/2d versions as CK will only support 3d convolutions +auto get_launch_params_1d(ck::host::Solution solution, + ck::Array out_lengths, + ck::Array out_strides) +{ + auto num_dim = solution.GetTemplateParameter("NumDim"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto k_per_block = solution.GetTemplateParameter("KPerBlock"); + auto GemmType = solution.GetTemplateParameter("GemmSpecialization"); + auto ConvType = solution.GetTemplateParameter("ConvSpecialization"); + ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType); + ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType); + auto conv_to_gemm_transformer = transform_conv_1d(num_dim, ConvSpec, out_lengths, out_strides); + auto matrix_padder = + pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer); + auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder); + return b2e; +} + +auto get_launch_params(ck::host::Solution solution, + ck::Array out_lengths, + ck::Array out_strides) +{ + auto num_dim = solution.GetTemplateParameter("NumDim"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto k_per_block = solution.GetTemplateParameter("KPerBlock"); + auto GemmType = solution.GetTemplateParameter("GemmSpecialization"); + auto ConvType = solution.GetTemplateParameter("ConvSpecialization"); + ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType); + ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType); + auto conv_to_gemm_transformer = transform_conv(num_dim, ConvSpec, out_lengths, out_strides); + auto matrix_padder = + pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer); + auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder); + return b2e; +} + +auto get_launch_params_3d(ck::host::Solution solution, + ck::Array out_lengths, + ck::Array out_strides) +{ + auto num_dim = solution.GetTemplateParameter("NumDim"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto k_per_block = solution.GetTemplateParameter("KPerBlock"); + auto GemmType = solution.GetTemplateParameter("GemmSpecialization"); + auto ConvType = solution.GetTemplateParameter("ConvSpecialization"); + ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType); + ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType); + auto conv_to_gemm_transformer = transform_conv_3d(num_dim, ConvSpec, out_lengths, out_strides); + auto matrix_padder = + pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer); + auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder); + return b2e; +} diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000..7ef4e7f184 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,781 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + if constexpr(isMultiA || isMultiB) + { + AsPointer p_as_grid_grp; + BsPointer p_bs_grid_grp; + + const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx); + + static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); + static_for<0, NumATensor, 1>{}( + [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; }); + + const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx); + + static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); + static_for<0, NumBTensor, 1>{}( + [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; }); + + GridwiseGemm::template Run( + p_as_grid_grp, + p_bs_grid_grp, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } + else + { + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + + GridwiseGemm::template Run( + p_as_grid + a_batch_offset, + p_bs_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } +#else + ignore = p_as_grid; + ignore = p_bs_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ + + device_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + AsPointer, // tuples if multi AB, pointers if no + BsPointer, + DsPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfBatch, + HasMainKBlockLoop, + isMultiA, + isMultiB>(p_as_grid, + p_bs_grid, + p_ds_grid, + *p_e_grid, + a_element_op, + b_element_op, + cde_element_op, + batch_count, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map, + compute_ptr_offset_of_batch); +} + +} // namespace + +template +using is_tuple = decltype(std::declval().IsTuple()); + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + LoopScheduler LoopSched = make_default_loop_scheduler()> +struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; + + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + __host__ __device__ static auto + MakeAGridDescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + return in_gemmm_gemmk_desc; + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_N_K(const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + return wei_gemmn_gemmk_desc; + } + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // Shape of Ds and E must be aligned. Strides can be different. + // Pass e_g_n_k_wos_lengths for logical broadcast. + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, + ds_g_n_k_wos_strides[i]); + }, + Number{}); + } + + // desc for problem definition + using AGridDesc_M_K = remove_cvref_t( + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using BGridDesc_N_K = remove_cvref_t({}, {}))>; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = remove_cvref_t({}, {}))>; + + // If we are using multiAB and one of the template datatype parameters is not a tuple, convert + // it to it + using GemmADataType = std::conditional_t, ADataType>; + using GemmBDataType = std::conditional_t, BDataType>; + +#define GridwiseGemmTemplateParameters \ + GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ + KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched + // Use appropriate gridwise gemm + using GridwiseGemm = + std::conditional_t, + GridwiseGemmMultipleD_xdl_cshuffle>; + + // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers. + using APointers = + std::conditional_t&, const void*>; + using BPointers = + std::conditional_t&, const void*>; + // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not + // in initializer list what is required for single const pointer). + using AGridPointer = remove_cvref_t< + decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>; + using BGridPointer = remove_cvref_t< + decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument + { + __device__ __host__ Argument( + APointers p_as, + BPointers p_bs, + const ck::Array& p_ds, + void* p_e, + const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array, NumDTensor>& ds_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_c_wis_lengths[0]}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + compute_ptr_offset_of_batch_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + if constexpr(isMultiA || isMultiB) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_batch_ for multiple AB + compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0]; + + // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data + // type is not tuple) + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. + if constexpr(isMultiA) + { + // p_as is tuple + p_as_grid_(i) = static_cast(p_as[i.value]); + } + else + { + // if MultiB and not MultiA then p_as is single pointer + p_as_grid_(i) = static_cast(p_as); + } + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_batch_ for multiple AB + compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0]; + + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. + if constexpr(isMultiB) + { + // p_bs is tuple + p_bs_grid_(i) = static_cast(p_bs[i.value]); + } + else + { + // if MultiA and not MultiB then p_bs is single pointer + p_bs_grid_(i) = static_cast(p_bs); + } + }); + } + else + { + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + + // p_as and p_bs are pointers + p_as_grid_(I0) = static_cast(p_as); + p_bs_grid_(I0) = static_cast(p_bs); + } + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + + // D batch stride + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + + // D desc + ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( + e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]); + }); + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + // populate desc for Ds/E + if constexpr(isMultiA || isMultiB) + { + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{}); + + if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + else + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + } + + // private: + // pointers (tuple if multi AB, pointer if no) + AGridPointer p_as_grid_; + BGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + index_t num_group_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch + compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + ck::Array a_g_n_c_wis_lengths_; + ck::Array a_g_n_c_wis_strides_; + ck::Array b_g_k_c_xs_lengths_; + ck::Array b_g_k_c_xs_strides_; + ck::Array, NumDTensor> ds_g_n_k_wos_lengths_; + ck::Array, NumDTensor> ds_g_n_k_wos_strides_; + ck::Array e_g_n_k_wos_lengths_; + ck::Array e_g_n_k_wos_strides_; + ck::Array conv_filter_strides_; + ck::Array conv_filter_dilations_; + ck::Array input_left_pads_; + ck::Array input_right_pads_; + }; + + static __device__ __host__ auto MakeArgument( + APointers p_as, + BPointers p_bs, + const ck::Array& p_ds, + void* p_e, + const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array, NumDTensor>& ds_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp index c66d2fc516..0294153147 100644 --- a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp +++ b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp @@ -180,6 +180,19 @@ struct MatrixPadder : public GemmPadder +auto grid_desc(MatrixPadder matrix_padder, + CDesc_MRaw_NRaw conv_desc) +{ + auto res = matrix_padder.PadCDescriptor_M_N(conv_desc); + return res; +} // M/N/KPerTileType could be index_t or Number<> template +__host__ __device__ auto mult_accumulate_n(ForwardIterator first, Size count, T init) +{ + for(ForwardIterator x = first; x != first + count; x++) + { + init *= *x; + } + return init; +} + template struct TransformConvFwdToGemm { @@ -607,6 +618,559 @@ struct TransformConvFwdToGemm return out_gemmm_gemmn_desc; } + + // Overloaded functions for hipRTC purposes + template || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ static auto + MakeADescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& /* b_g_k_c_xs_strides */, + const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& /* c_g_n_k_wos_strides */, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Wi = a_g_n_c_wis_lengths[3]; + + const index_t Wo = c_g_n_k_wos_lengths[3]; + + const index_t ConvStrideW = conv_filter_strides[0]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + // This is different + const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t WiStride = a_g_n_c_wis_strides[3]; + const auto CStride = I1; + + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t X = b_g_k_c_xs_lengths[3]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t WiStride = a_g_n_c_wis_strides[3]; + const auto CStride = I1; + + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_merge_transform(make_tuple(X, C))), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ static auto + MakeADescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& /* b_g_k_c_xs_strides */, + const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& /* c_g_n_k_wos_strides */, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Hi = a_g_n_c_wis_lengths[3]; + const index_t Wi = a_g_n_c_wis_lengths[4]; + + const index_t Ho = c_g_n_k_wos_lengths[3]; + const index_t Wo = c_g_n_k_wos_lengths[4]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + // This is different + const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t HiStride = a_g_n_c_wis_strides[3]; + const index_t WiStride = a_g_n_c_wis_strides[4]; + const auto CStride = I1; + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + + const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t Y = b_g_k_c_xs_lengths[3]; + const index_t X = b_g_k_c_xs_lengths[4]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t HiStride = a_g_n_c_wis_strides[3]; + const index_t WiStride = a_g_n_c_wis_strides[4]; + const auto CStride = I1; + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_merge_transform(make_tuple(Y, X, C))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template || + is_same_v || + is_same_v), + bool>::type = false> + static auto + MakeADescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& /* b_g_k_c_xs_strides */, + const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& /* c_g_n_k_wos_strides */, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Di = a_g_n_c_wis_lengths[3]; + const index_t Hi = a_g_n_c_wis_lengths[4]; + const index_t Wi = a_g_n_c_wis_lengths[5]; + + const index_t Do = c_g_n_k_wos_lengths[3]; + const index_t Ho = c_g_n_k_wos_lengths[4]; + const index_t Wo = c_g_n_k_wos_lengths[5]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NDoHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + // This is different + const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t DiStride = a_g_n_c_wis_strides[3]; + const index_t HiStride = a_g_n_c_wis_strides[4]; + const index_t WiStride = a_g_n_c_wis_strides[5]; + const auto CStride = I1; + + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t Z = b_g_k_c_xs_lengths[3]; + const index_t Y = b_g_k_c_xs_lengths[4]; + const index_t X = b_g_k_c_xs_lengths[5]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t DiStride = a_g_n_c_wis_strides[3]; + const index_t HiStride = a_g_n_c_wis_strides[4]; + const index_t WiStride = a_g_n_c_wis_strides[5]; + const auto CStride = I1; + + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_merge_transform(make_tuple(Z, Y, X, C))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template || + is_same_v || + is_same_v, + bool>::type = false> + __host__ __device__ static auto + MakeBDescriptor_N_K(const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& /* b_g_k_c_xs_strides */) + { + const index_t K = b_g_k_c_xs_lengths[1]; + const index_t C = b_g_k_c_xs_lengths[2]; + + const index_t YX = + mult_accumulate_n(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1); + + const auto wei_gemmn_gemmk_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, YX * C)); + + return wei_gemmn_gemmk_desc; + } + + template < + typename BLayout, + typename std::enable_if || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v, + bool>::type = false> + __host__ __device__ static auto + MakeBDescriptor_N_K(const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides) + { + const index_t K = b_g_k_c_xs_lengths[1]; + const index_t C = b_g_k_c_xs_lengths[2]; + + const index_t YX = + mult_accumulate_n(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1); + + const index_t KStride = b_g_k_c_xs_strides[1]; + const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto wei_k_yx_c_desc = make_naive_tensor_descriptor( + make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride)); + + const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor( + wei_k_yx_c_desc, + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return wei_gemmn_gemmk_desc; + } + + template || + is_same_v || + is_same_v, + bool>::type = false> + __host__ __device__ static auto + MakeCDescriptor_M_N(const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& /* c_g_n_k_wos_strides */) + { + const index_t N = c_g_n_k_wos_lengths[1]; + const index_t K = c_g_n_k_wos_lengths[2]; + + const index_t NHoWo = + N * mult_accumulate_n(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1); + + const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K)); + + return out_gemmm_gemmn_desc; + } + + template < + typename CLayout, + typename std::enable_if || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v, + bool>::type = false> + __host__ __device__ static auto + MakeCDescriptor_M_N(const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& c_g_n_k_wos_strides) + { + const index_t N = c_g_n_k_wos_lengths[1]; + const index_t K = c_g_n_k_wos_lengths[2]; + + const auto KStride = I1; + const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2]; + + const index_t NHoWo = + N * mult_accumulate_n(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1); + + const auto out_gemmm_gemmn_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride)); + + return out_gemmm_gemmn_desc; + } + + // for output bias + template , + bool>::type = false> + __host__ __device__ static auto + MakeCDescriptor_M_N(const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& c_g_n_k_wos_strides) + { + const index_t N = c_g_n_k_wos_lengths[1]; + const index_t K = c_g_n_k_wos_lengths[2]; + const index_t KStride = c_g_n_k_wos_strides[2]; + + const index_t NHoWo = + N * mult_accumulate_n(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1); + + const auto out_gemmm_gemmn_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride)); + + return out_gemmm_gemmn_desc; + } +}; + +// wrapper class to call member functions on TransformConvToGemm struct at runtime +// TODO: figure out aq way to properly pass in layout as an argument +struct TransformConv +{ + TransformConv() {} + + template + auto + transform_func(ck::Array out_lengths, + ck::Array out_strides, + TransformConvFwdToGemm conv_fwd_to_gemm) + { + if(NDimSpatial == 2) + { + return conv_fwd_to_gemm + .template MakeCDescriptor_M_N(out_lengths, + out_strides); + } + else if(NDimSpatial == 3) + { + return conv_fwd_to_gemm + .template MakeCDescriptor_M_N(out_lengths, + out_strides); + } + else if(NDimSpatial == 1) + { + return conv_fwd_to_gemm.template MakeCDescriptor_M_N( + out_lengths, out_strides); + } + } }; } // namespace tensor_operation diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index f63ce5e5a0..5366c56a9d 100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -36,6 +36,8 @@ struct Array return *this; } + __host__ __device__ constexpr const TData* begin() const { return &mData[0]; } + __host__ __device__ constexpr const TData* end() const { return &mData[NSize]; } }; // empty Array From 0cb2e06ddcdf6b1414cbead4262b76b5d4391e93 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Wed, 26 Jun 2024 17:41:15 +0800 Subject: [PATCH 03/16] [CK_TILE] fmha forward split-kv + combine kernels (#1338) * FA fwd dropout * FA bwd * epilogue reuse * CMakeLists update * [CK_TILE] support alibi (#1269) * add alibi support * fix code * update code based on comment * Support more hdim * fix fp8 bias * support seqlen_k=0 case * remove unused printf * fix format --------- Co-authored-by: rocking * now fwd/bwd can build * bwd alibi * add bwd validation stream_config * update generated filenames * update bwd kernel launch * CK_TILE_HOST_DEVICE in philox * Transpose -> transpose * format * format * format * Generate the instance for FA required * format * fix error in WarpGemm * Add num_splits option and dummy split-kv api method * Generate fmha_fwd_splitkv() * Add SplitKV kernel codegen logics * Add SplitKV combine kernel codegen logics * Fix mismatched return type * Clean-up code * Replace sentinel value before storing * Fix wrong layout of LSE/LSEacc/Oacc * Format codes * Fix o_acc memory error * Fix wrong kBlockSize used in policy * Reduce # of combine kernels * Fix split-kv combine kernel name * Fix wrong LDS indexing logics * Fix wrong loop counter step logic * Undo vector size changes * Remove no-longer used field * Remove in-consistent comment * Remove debug statements in example * Remove more debug statements * Add constness to local variables * Clearn up generate.py * Fix unstable clang-format comment * Remove unused include directive * Use shorter template parameter name * Enable non-split-kv blobs * Update license date * Print num_splits conditionally * Undo disabling data types * Remove unnessary tile size for fp8 * Fix wrong pipeline args for fp8 * Fix example output format * Remove more debug code in combine pipeline * Add stride kernel arguments for LSE/O acc workspace * Re-order split-kv pipeline call operator arguments * Pass LSE/O strides in kernel argument * Re-order pipeline call operator arguments * Use tensor_descriptor to locate LSEacc elements * Support providing invalid element for tensor view * Set invalid element value for LSEacc tensor view * Remove hand-written store_tile() code * Remove necessary value-overwrite logic * Add transposed lds descriptor * Support load_tile() for tile_window_with_static_lengths<> * Undo removing necessary value-overwrite logic * Use read descriptor to locate lds elements * Simplify pipeline source code * Add constraint to kMaxSplits * Default use kMaxSplits=64 in generate.py * Revert "Add constraint to kMaxSplits" This reverts commit 0a2132d758042e6fb0292f4e354909b8a4d1c118. * Revert "Default use kMaxSplits=64 in generate.py" This reverts commit c7d9c80b77320aec6559222bed7d47adcaefe4e3. * Decide alignment by the padding parameter * Remove no-longer used utility functions * Remove not-working code * Add comment & remove no-longer used code * Fix computation errors * Add heuristic to override num_splits option * Add constraint to kMaxSplits * Fix compilation error * Clean up pipeline code * Wrap pointer access as lambda function * Rename confusing methods * Use kLogMasSplits as template parameter * Finish splitkv combine kernel codegen * Update kMaxSplits limit * Use smaller kM0 for splitkv combine kernel * Ignore droupout flag in splitkv pipeline * Unify flag usage * Add back flag kStoreLSE * Merge lambda calls in pipeline * Fix compilation errors * Avoid all empty splits * Always check for empty loop in splitkv pipelines * Re-order parameters * Remove redundant p_drop option check * Add traits/problem for fwd splitkv kernel * Conditionally enable uneven split boundary checks * Add comment for the splitkv traits field * Change even split criteria * Re-order statements * Refine occupancy value for hdim=128&256 * Refine occupancy value for hdim=32&64 * Remove redundant kernel argument * Separate fmha bwd codegen logics * Separate fmha fwd codegen logics * Remove redundant direction parameter in fwd&bwd codegen logics * Support generate multiple APIs for an example * Let 'api' an alias of 'direction' option * Remove choices for the 'direction' option * Use dictionary to config all the functions * Move fmha splitkv codegen logics to other file * Add fwd_splitkv api for tile_example_fmha_fwd --------- Co-authored-by: danyao12 Co-authored-by: carlushuang Co-authored-by: rocking Co-authored-by: Jing Zhang --- example/ck_tile/01_fmha/CMakeLists.txt | 8 +- example/ck_tile/01_fmha/codegen/__init__.py | 0 .../ck_tile/01_fmha/codegen/cmake_config.py | 5 + .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 92 ++ .../ck_tile/01_fmha/codegen/ops/__init__.py | 0 .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 611 ++++++++ .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 498 +++++++ .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 671 +++++++++ example/ck_tile/01_fmha/fmha_fwd.cpp | 170 ++- example/ck_tile/01_fmha/fmha_fwd.hpp | 215 +++ example/ck_tile/01_fmha/generate.py | 1223 +---------------- include/ck_tile/ops/fmha.hpp | 10 + .../ck_tile/ops/fmha/block/block_masking.hpp | 17 + .../fmha_fwd_splitkv_combine_kernel.hpp | 455 ++++++ ...a_fwd_splitkv_combine_tile_partitioner.hpp | 49 + .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 913 ++++++++++++ .../fmha_fwd_splitkv_tile_partitioner.hpp | 53 + ...lock_fmha_fwd_splitkv_combine_pipeline.hpp | 314 +++++ ...plitkv_combine_pipeline_default_policy.hpp | 175 +++ ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 666 +++++++++ ...ha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp | 770 +++++++++++ ...pipeline_qr_ks_vs_async_default_policy.hpp | 19 + ...litkv_pipeline_qr_ks_vs_default_policy.hpp | 19 + .../pipeline/block_fmha_pipeline_problem.hpp | 65 + .../ops/fmha/pipeline/tile_fmha_traits.hpp | 44 + 25 files changed, 5861 insertions(+), 1201 deletions(-) create mode 100644 example/ck_tile/01_fmha/codegen/__init__.py create mode 100644 example/ck_tile/01_fmha/codegen/cmake_config.py create mode 100644 example/ck_tile/01_fmha/codegen/cpp_symbol_map.py create mode 100644 example/ck_tile/01_fmha/codegen/ops/__init__.py create mode 100644 example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py create mode 100644 example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py create mode 100644 example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index e324f85ed8..e30e9e793c 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -1,12 +1,12 @@ # generate a list of kernels, but not actually emit files at config stage execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt + --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt ) execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt + --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt ) # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory @@ -17,13 +17,13 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) add_custom_command( OUTPUT ${FMHA_FWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR} ) add_custom_command( OUTPUT ${FMHA_BWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} ) set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") diff --git a/example/ck_tile/01_fmha/codegen/__init__.py b/example/ck_tile/01_fmha/codegen/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/01_fmha/codegen/cmake_config.py b/example/ck_tile/01_fmha/codegen/cmake_config.py new file mode 100644 index 0000000000..03ebfd6702 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/cmake_config.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +GEN_DIR = "" # in Cmake, have to generate files in same folder \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py new file mode 100644 index 0000000000..d3d215f7f5 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +DTYPE_MAP = { + "fp16": "ck_tile::fp16_t", + "bf16": "ck_tile::bf16_t", + "fp8" : "ck_tile::fp8_t" +} + +MASK_IMPL = { + "generic" : "ck_tile::GenericAttentionMask", + "simplified" : "ck_tile::SimplifiedGenericAttentionMask" +} + +_MASK_SIMPLIFIED_MAP = { + "s_no" : "ck_tile::SimplifiedGenericAttentionMask", + "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", +} + +_MASK_MAP = { + "no" : "FmhaMasks::NoMask", + "causal" : "FmhaMasks::CausalMask", + "generic" : "FmhaMasks::GenericMask" +} + +def get_mask_map(mask : str): + if mask == "generic": + return _MASK_MAP + elif mask == "simplified": + return _MASK_SIMPLIFIED_MAP + else: + assert False + return None + +_MASK_CHECK_MAP = { + "no" : "t.mask_type == mask_enum::no_mask", + "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", + "generic" : "t.mask_type == mask_enum::window_generic", +} + +_MASK_SIMPLIFIED_CHECK_MAP = { + "s_no" : "t.mask_type == mask_enum::no_mask", + "s_mask" : "t.mask_type != mask_enum::no_mask", +} + +def get_mask_check_map(mask : str): + if mask == "generic": + return _MASK_CHECK_MAP + elif mask == "simplified": + return _MASK_SIMPLIFIED_CHECK_MAP + else: + assert False + return None + +BIAS_MAP = { + "no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS", + "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", + "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" +} + +# TODO: this is ugly +BIAS_CHECK_MAP = { + "no" : "bias_enum::no_bias", + "bias" : "bias_enum::elementwise_bias", + "alibi" : "bias_enum::alibi" +} + +MODE_MAP = { + "batch" : "false", + "group" : "true" +} + +LAYOUT_MAP = { + "row" : "true", + "col" : "false" +} + +PIPELINE_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", +} + +PIPELINE_ENUM_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", +} + +BOOL_MAP = { + "t" : "true", + "f" : "false" +} \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/ops/__init__.py b/example/ck_tile/01_fmha/codegen/ops/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py new file mode 100644 index 0000000000..0160915a54 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -0,0 +1,611 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +BWD_DQDKDV_PIPELINE_MAP = { + "ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR", + "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS", + "ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR", +} + +BWD_DQDKDV_PIPELINE_ENUM_MAP = { + "ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR", + "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS", + "ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR", +} + +FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "fmha_bwd.hpp" +""" + +FMHA_BWD_DQ_DK_DV_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>; +using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>; +using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; +using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; +using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + {F_dbias}, + false, + {F_dropout}, + false, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_bwd_trait_{F_idx}>; + +using fmha_bwd_pipeline_{F_idx} = {F_pipeline}< + fmha_bwd_pipeline_problem_{F_idx}>; + +using fmha_bwd_dk_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, + false, false>>; + +using fmha_bwd_dv_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, + false, false>>; + +using fmha_bwd_dq_dk_dv_kernel_{F_idx} = + ck_tile::FmhaBwdDQDKDVKernel, + fmha_bwd_pipeline_{F_idx}, + fmha_bwd_dk_epilogue_{F_idx}, + fmha_bwd_dv_epilogue_{F_idx}>; + +using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +std::string fmha_bwd_dq_dk_dv_get_name_() +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" +FMHA_BWD_API=""" +#include + +template +float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} + ); +}} + +float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; + r = fmha_bwd_(s, a); + return r; + }} +""" + +@dataclass +class FmhaBwdDQDKDVApiTrait: + pipeline : str + # sync with fmha_bwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along k seqlen + bhdq : int # q head_dim + bhdv : int # v head_dim + mask : str + bias : str + dbias : str + dropout : str + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + def scheck(self, spad1 : str) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.spad == 't' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} != 0' + elif self.spad == 'f' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize + else: # self.skpad == 'f' and skpad1 == 'f' + return f'a.seqlen_q % 256 == 0' # BlockSize + + @property + def skcheck(self) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.skpad == 't': + return f'a.seqlen_k % {self.bn0} != 0' + else: + return f'a.seqlen_k % {self.bn0} == 0' + + @property + def dcheck(self) -> str: + if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' + else : return f'a.hdim_q % {self.bhdq} == 0' + + @property + def dvcheck(self) -> str: + if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' + else : return f'a.hdim_v % {self.bhdv} == 0' + +class FmhaBwdApiPool: + def __init__(self, mask_impl): + self.dq_dk_dv_pool = dict() + self.mask_impl = mask_impl + + def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.dq_dk_dv_pool.keys(): + self.dq_dk_dv_pool[trait.dtype] = dict() + if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): + self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() + + self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): + traits=self.dq_dk_dv_pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + for spad1 in ["t", "f"]: + if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")): + continue + inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout], + F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], + F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad]) + + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) + +# GEMM0: Q@K=S^T +# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) +# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) +# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk) +# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk) +# Is it necessary to distinguish between K0~K4? +@dataclass +class FmhaBwdDQDKDVTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along gemm0 unroll(F_bhdq) + F_bk1 : int # tile size along gemm1 unroll(F_bm0) + F_bk2 : int # tile size along gemm2 unroll(F_bhdv) + F_bk3 : int # tile size along gemm3 unroll(F_bm0) + F_bk4 : int # tile size along gemm4 unroll(F_bn0) + F_bhdq : int # q head_dim + F_bhdv : int # v head_dim + F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2 + F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2 + F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2 + F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3 + F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3 + F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3 + F_rm2 : int # number of warps along k seqlen (block warps) in gemm4 + F_rn2 : int # number of warps along q seqlen (block warps) in gemm4 + F_rk2 : int # number of warps along gemm-k (not used) in gemm4 + F_wm : int # warp size along m (warp size) + F_wn : int # warp size along n + F_wk : int # warp size along k + F_occupancy : int # occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ + f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}" + +@dataclass +class FmhaBwdDQDKDVKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_tile : FmhaBwdDQDKDVTileSize + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # + F_dbias : str # + F_dropout : str # + F_mask : str # value from MASK_MAP + F_mode : str # value from MODE_MAP + F_pipeline : str + mask_impl : str + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bk1 = self.F_tile.F_bk1, + F_bk2 = self.F_tile.F_bk2, + F_bk3 = self.F_tile.F_bk3, + F_bk4 = self.F_tile.F_bk4, + F_bhdq = self.F_tile.F_bhdq, + F_bhdv = self.F_tile.F_bhdv, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_rm2 = self.F_tile.F_rm2, + F_rn2 = self.F_tile.F_rn2, + F_rk2 = self.F_tile.F_rk2, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_spad = BOOL_MAP[self.F_spad], + F_skpad = BOOL_MAP[self.F_skpad], + F_dpad = BOOL_MAP[self.F_dpad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_bias = BIAS_MAP[self.F_bias], + F_dbias = BOOL_MAP[self.F_dbias], + F_dropout = BOOL_MAP[self.F_dropout], + F_occupancy = self.F_tile.F_occupancy, + F_mask = get_mask_map(self.mask_impl)[self.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline], + F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline]) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_dbias == 't' : n += '_dbias' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_dropout == 't' : n += '_dropout' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaBwdDQDKDVApiTrait: + return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bhdq=self.F_tile.F_bhdq, + bhdv=self.F_tile.F_bhdv, + mask=self.F_mask, + bias=self.F_bias, + dbias=self.F_dbias, + dropout=self.F_dropout, + spad=self.F_spad, + skpad=self.F_skpad, + dpad=self.F_dpad, + dvpad=self.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size & pipeline. +def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1), + "qs_ks_vr_dos"], + '64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), + "qs_ks_vr_dos"], + '128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), + "ks_vr"] + } + else: + return None + +def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]: + # TODO: we don't support tuning yet, so pick up one value for pad + # support this in future + gen = list() + api_pool = FmhaBwdApiPool(mask_impl) + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) + if d == None: + continue + for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): + tile = d[hdim_str][0] + ppl = d[hdim_str][1] + hdim = int(hdim_str) + if (mode == "group") and (spad == "f" or skpad == "f"): + continue + if ((bias == "no" or bias == "alibi") and dbias == "t"): + continue + k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, + F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, + F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, + F_pipeline=ppl, mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + if not cond: + continue + api_pool.register_dq_dk_dv_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, + {F_dvpad}, + {F_occupancy}>; + +using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 256, + {F_hdim}, + {F_mode}, + fmha_bwd_dot_do_o_trait_{F_idx}>; + +using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO< + fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>; + +using fmha_bwd_dot_do_o_kernel_{F_idx} = + ck_tile::FmhaBwdOGradDotOKernel, + fmha_bwd_dot_do_o_{F_idx}>; + +using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; + +#include + +template<> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +std::string fmha_bwd_dot_do_o_get_name_() +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +@dataclass +class FmhaBwdOGradDotOKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_spad : str # true/false + F_dvpad : str # + F_mode : str # value from MODE_MAP + F_occupancy : int + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_spad = BOOL_MAP[self.F_spad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_mode = MODE_MAP[self.F_mode], + F_occupancy = self.F_occupancy) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" + if pn != '' : n += f'_{pn}' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + +def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + gen = list() + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) + if d == None: + continue + for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]): + hdim = int(hdim_str) + if (mode == "group" and spad == "f"): + continue + k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, + F_spad=spad, F_dvpad=dvpad, F_mode=mode, + F_occupancy=get_occupancy(dtype, hdim)) + gen.append(k) + + return gen + +def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + kernels = get_bwd_dot_do_o_blobs() + for kernel in kernels: + write_single_bwd_dot_do_o_kernel(kernel, output_dir) + api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) + write_bwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + with file_path.open('a') as f: + kernels = get_bwd_dot_do_o_blobs() + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py new file mode 100644 index 0000000000..1486671f6b --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -0,0 +1,498 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +TILE_PARTITIONER_MAP = { + "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", + "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; +using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; +using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdKernel<{F_tile_partitioner}, + fmha_pipeline_{F_idx}, + fmha_epilogue_{F_idx}>; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" +FMHA_FWD_API=""" +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + return fmha_fwd_(s, a); + }} +""" + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0blen : int + vlayout : str + mask : str + bias : str # + lse : str # + dropout : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ + f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {self.bk0blen} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {self.bk0blen} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_lse == 't' : n += '_lse' + if self.F_dropout == 't' : n += '_dropout' + if self.F_squant == 't' : n += '_squant' + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, + F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm : int # number of warps along q seqlen (block warps) + F_rn : int # number of warps along k seqlen(not used) + F_rk : int # number of warps along gemm-k(not used) + F_wm : int # warp size along m (warp size) + F_wn : int # warp size along n + F_wk : int # warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\ + f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + def get_tp(self) -> str: + if self.F_mode == 'group': + return 'hbs' + else: + return 'shb' + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0blen = self.F_tile.F_bk0blen, + F_rm = self.F_tile.F_rm, + F_rn = self.F_tile.F_rn, + F_rk = self.F_tile.F_rk, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0blen=self.F_tile.F_bk0blen, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) + } + else: + return None + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + if hdim == 256: + # if True: + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + else: + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + if receipt == 1: + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'bf8']: + # no need lse/dropout kernels + for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + k = FmhaFwdKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py new file mode 100644 index 0000000000..419fbaaea8 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -0,0 +1,671 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple, Union + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + +from codegen.ops.fmha_fwd import ( + FmhaFwdTileSize, + FmhaFwdApiTrait, + FMHA_FWD_KERNEL_HEADER, + FMHA_FWD_API_PER_DTYPE, + FMHA_FWD_API_PER_HDIM_CASE, +) + + +FMHA_FWD_SPLITKV_PIPELINE_MAP = { + "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", + "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync", +} + +FMHA_FWD_SPLITKV_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; +using fmha_mask_{F_idx} = {F_mask}; + +namespace {{ +template +struct kernel_runner {{ +using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; +using fmha_block_warps = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; +using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +using fmha_shape = ck_tile::TileFmhaShape; + +using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + kHasUnevenSplits, + {F_occupancy}>; + +using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::OaccDataType, + fmha_shape, + {F_mode}, + fmha_mask_{F_idx}, + fmha_trait>; + +using fmha_pipeline = {F_pipeline}< + fmha_pipeline_problem>; + +using fmha_epilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel = + ck_tile::FmhaFwdSplitKVKernel, + fmha_pipeline, + fmha_epilogue>; + +static void run(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel; + auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} +}}; +}} + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + if constexpr({F_mode} == false) {{ // batch mode + if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ + kernel_runner::run(s, a); + }} else {{ + kernel_runner::run(s, a); + }} + }} else {{ + kernel_runner::run(s, a); + }} +}} + +template<> +std::string fmha_fwd_splitkv_get_name_() +{{ + using k_ = kernel_runner::fmha_kernel; /// FIXME: choose real kernel type + return k_::GetName(); +}} +""" + +FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +namespace {{ +template +struct kernel_runner {{ +using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad}, + {F_dvpad}, + {F_lse}, + {F_squant}, + kLogMaxSplits, + {F_occupancy}>; + +using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem< + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + {F_hdim}, + {F_bm0}, + {F_bn1}, + {F_mode}, + fmha_trait>; + +using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline< + fmha_pipeline_problem>; + +using fmha_epilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel = + ck_tile::FmhaFwdSplitKVCombineKernel, + fmha_pipeline, + fmha_epilogue>; + +static void run(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel; + auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} +}}; +}} + +using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn1}, + {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; + +#include + +template<> +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + if (a.num_splits <= 16) {{ + kernel_runner<4>::run(s, a); + }} else if (a.num_splits <= 32) {{ + kernel_runner<5>::run(s, a); + }} else if (a.num_splits <= 64) {{ + kernel_runner<6>::run(s, a); + }} else if (a.num_splits <= 128) {{ + kernel_runner<7>::run(s, a); + }} +}} + +template<> +std::string fmha_fwd_splitkv_combine_get_name_() +{{ + using k_ = kernel_runner<6>::fmha_kernel; /// FIXME: choose real kernel type + return k_::GetName(); +}} +""" + +FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp" +FMHA_FWD_SPLITKV_API=""" +#include + +template +float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + if(s.log_level_ > 0) + std::cout + << ", " << fmha_fwd_splitkv_get_name_() + << ", " << fmha_fwd_splitkv_combine_get_name_() + << std::flush; + + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} + ); +}} + +float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; + + return fmha_fwd_splitkv_(s, a); + }} +""" + +@dataclass +class FmhaFwdSplitKVPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_lse == 't' : n += '_lse' + if self.F_dropout == 't' : n += '_dropout' + if self.F_squant == 't' : n += '_squant' + return n + +@dataclass +class FmhaFwdSplitKVCombinePipeline: + tag : str + + F_spad : str # true/false + F_dvpad : str # + F_lse : str # + F_squant : str # + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}' + if pn != '' : n += f'_{pn}' + if self.F_lse == 't' : n += '_lse' + if self.F_squant == 't' : n += '_squant' + return n + +class FmhaFwdSplitKVApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, + F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdSplitKVCombineTileSize: + F_bm0 : int # tile size along q seqlen + F_bn1 : int # tile size along v head_dim + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdSplitKVKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdSplitKVPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_SPLITKV_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0blen = self.F_tile.F_bk0blen, + F_rm = self.F_tile.F_rm, + F_rn = self.F_tile.F_rn, + F_rk = self.F_tile.F_rk, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = FMHA_FWD_SPLITKV_PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_splitkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0blen=self.F_tile.F_bk0blen, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +@dataclass +class FmhaFwdSplitKVCombineKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdSplitKVCombineTileSize + F_pipeline : FmhaFwdSplitKVCombinePipeline + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn1 = self.F_tile.F_bn1, + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_mode = MODE_MAP[self.F_mode]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_splitkv_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0blen=self.F_tile.F_bk0blen, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) + } + else: + return None + +def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdSplitKVCombineTileSize(64, 32, -1), + '64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), + '128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), + '256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), + '128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), + '256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1), + } + else: + return None + +def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: + Pipeline = FmhaFwdSplitKVPipeline + Kernel = FmhaFwdSplitKVKernel + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + # splitkv kernel donot support dropout + for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"]): + if hdim == 256: + # if True: + pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + else: + pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + if receipt == 1: + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'bf8']: + # no need lse/dropout kernels + for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdSplitKVApiPool(mask_impl) + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + k = Kernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaFwdSplitKVCombineKernel]: + Pipeline = FmhaFwdSplitKVCombinePipeline + Kernel = FmhaFwdSplitKVCombineKernel + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for spad, dvpad, lse in itertools.product(["t", "f"], ["t", "f"], ["t", "f"]): + pipelines.append(Pipeline('unused', spad, dvpad, lse, squant)) + elif dtype in ['fp8', 'bf8']: + # no need lse kernels + pipelines.append(Pipeline('unused', 'f', 'f', 'f', squant)) + else: + assert False + return pipelines + + gen = list() + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + k = Kernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + gen.append(k) + + return gen + +def write_single_kernel(kernel: Union[FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) -> None: + file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME + file_path.write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt) + for kernel in kernels: + write_single_kernel(kernel, output_dir) + api_pool, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_kernel(kernel, output_dir) + write_fwd_splitkv_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + with file_path.open('a') as f: + kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + _, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") \ No newline at end of file diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 5f887f0655..28f7905734 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -114,6 +114,9 @@ auto create_args(int argc, char* argv[]) .insert("drop_seed", "1", "seed for random number generator") .insert("drop_offset", "0", "offset for random number generator") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("num_splits", + "1", + "# of splits for key/value. 0 to determine actual number by heuristic") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -155,6 +158,106 @@ auto get_elimit(std::string init_method) } } +int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits) +{ + // If we have enough to almost fill the SMs, then just use 1 split + if(batch_nhead_mblocks >= 0.8f * num_SMs) + { + return 1; + } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. + auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || + ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for(int num_splits = 1; num_splits <= max_splits; num_splits++) + { + if(!is_split_eligible(num_splits)) + { + efficiency.push_back(0.f); + } + else + { + float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if(eff > max_efficiency) + { + max_efficiency = eff; + } + efficiency.push_back(eff); + } + } + for(int num_splits = 1; num_splits <= max_splits; num_splits++) + { + if(!is_split_eligible(num_splits)) + { + continue; + } + if(efficiency[num_splits - 1] >= 0.85 * max_efficiency) + { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +int override_num_splits_if_necessary( + int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) +{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return num_splits; + } + + hipDeviceProp_t props{}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return num_splits; + } + + // tile size should match the generate.py + const int kM0 = 64; + const int kN1 = hdim_v; + + const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); + const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); + + if(num_splits < 1 && p_drop == 0.0f) + { + return num_splits_heuristic( + batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); + } + + return num_splits; +} + +float fmha_fwd_dispatch(fmha_fwd_traits traits, + fmha_fwd_args args, + const ck_tile::stream_config& config) +{ + if(1 < args.num_splits) + { + return fmha_fwd_splitkv(traits, args, config); + } + else + { + return fmha_fwd(traits, args, config); + } +} + template bool run(const ck_tile::ArgParser& arg_parser) { @@ -260,6 +363,8 @@ bool run(const ck_tile::ArgParser& arg_parser) seed.reset(); } + int num_splits = arg_parser.get_int("num_splits"); + int stream_warmup = arg_parser.get_int("warmup"); int stream_repeat = arg_parser.get_int("repeat"); bool kname = arg_parser.get_bool("kname"); @@ -320,6 +425,18 @@ bool run(const ck_tile::ArgParser& arg_parser) } } + // legalize num_splits according to other options + if(num_splits < 1) + { + num_splits = override_num_splits_if_necessary( + batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits); + } + if(128 < num_splits) + { + std::cerr << "num_splits greater than 128 is not supported" << std::endl; + return false; + } + auto get_lengths = [&](bool permute, ck_tile::index_t b /*batch*/, ck_tile::index_t h /*nhead*/, @@ -361,7 +478,15 @@ bool run(const ck_tile::ArgParser& arg_parser) : std::array{batch, nhead}) : std::array{1, 1}); - // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] + ck_tile::HostTensor lse_acc_host( + 1 < num_splits ? std::array{num_splits, batch, nhead, max_seqlen_q} + : std::array{1, 1, 1, 1}); + ck_tile::HostTensor o_acc_host( + 1 < num_splits + ? std::array{num_splits, batch, nhead, max_seqlen_q, hdim_v} + : std::array{1, 1, 1, 1, 1}); + + // self define lse data layout as [batch, nhead, max_seqlen_q] ck_tile::HostTensor lse_host( lse ? std::array{batch, nhead, max_seqlen_q} : std::array{1, 1, 1} /* dummy shape for simplifying code */); @@ -443,6 +568,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); @@ -479,7 +606,12 @@ bool run(const ck_tile::ArgParser& arg_parser) : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant - << ", mask:" << mask << ", v:" << vlayout << std::flush; + << ", mask:" << mask << ", v:" << vlayout; + if(1 < num_splits) + { + std::cout << ", num_splits:" << num_splits; + } + std::cout << std::flush; auto fmha_traits = fmha_fwd_traits{hdim_q, hdim_v, @@ -523,6 +655,7 @@ bool run(const ck_tile::ArgParser& arg_parser) }(); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k); + const ck_tile::index_t stride_o_acc = hdim_v; const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); @@ -537,6 +670,8 @@ bool run(const ck_tile::ArgParser& arg_parser) (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_lse = max_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q; + const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); @@ -545,7 +680,12 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q); + const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q); + const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + // setup split_stride_* arguments (only used in split-kv kernel) + const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q); + const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v); return fmha_fwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), @@ -553,6 +693,8 @@ bool run(const ck_tile::ArgParser& arg_parser) bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() : bias_buf.GetDeviceBuffer(), randval_buf.GetDeviceBuffer(), + lse_acc_buf.GetDeviceBuffer(), + o_acc_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), @@ -566,6 +708,7 @@ bool run(const ck_tile::ArgParser& arg_parser) hdim_v, nhead, nhead_k, + num_splits, scale_s, scale_p, scale_o, @@ -575,6 +718,7 @@ bool run(const ck_tile::ArgParser& arg_parser) bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias, stride_randval, + stride_o_acc, stride_o, nhead_stride_q, nhead_stride_k, @@ -582,6 +726,8 @@ bool run(const ck_tile::ArgParser& arg_parser) nhead_stride_bias, nhead_stride_randval, nhead_stride_lse, + nhead_stride_lse_acc, + nhead_stride_o_acc, nhead_stride_o, batch_stride_q, batch_stride_k, @@ -589,7 +735,11 @@ bool run(const ck_tile::ArgParser& arg_parser) batch_stride_bias, batch_stride_randval, batch_stride_lse, + batch_stride_lse_acc, + batch_stride_o_acc, batch_stride_o, + split_stride_lse_acc, + split_stride_o_acc, mask.left, mask.right, static_cast(mask.type), @@ -598,7 +748,7 @@ bool run(const ck_tile::ArgParser& arg_parser) {drop_seed, drop_offset}}; }(); - float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); + float ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config); if(ave_time < 0) { @@ -849,14 +999,14 @@ bool run(const ck_tile::ArgParser& arg_parser) lse_host_result.ForEach( [&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); }); - bool lse_pass = ck_tile::check_err(lse_host_result, - lse_host_ref, - "LSE Error: Incorrect results!", - rtol, - atol, - /* allow_infinity_ref = */ true); + cur_pass = ck_tile::check_err(lse_host_result, + lse_host_ref, + "LSE Error: Incorrect results!", + rtol, + atol, + /* allow_infinity_ref = */ true); - pass &= lse_pass; + pass &= cur_pass; if(!cur_pass) { std::cerr << "LSE mismatch found at batch: " << wb << std::endl diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 3594f61db9..ee932ce5d9 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -93,6 +93,8 @@ struct fmha_fwd_args const void* v_ptr; const void* bias_ptr; // bias or alibi_slope pointer void* rand_val_ptr; + void* lse_acc_ptr; + void* o_acc_ptr; void* lse_ptr; void* o_ptr; const void* seqstart_q_ptr; @@ -106,6 +108,7 @@ struct fmha_fwd_args ck_tile::index_t hdim_v; ck_tile::index_t nhead_q; ck_tile::index_t nhead_k; + ck_tile::index_t num_splits; float scale_s; float scale_p; float scale_o; @@ -114,6 +117,7 @@ struct fmha_fwd_args ck_tile::index_t stride_v; ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 ck_tile::index_t stride_randval; + ck_tile::index_t stride_o_acc; ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; @@ -121,6 +125,8 @@ struct fmha_fwd_args ck_tile::index_t nhead_stride_bias; ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -128,7 +134,11 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_bias; ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; ck_tile::index_t batch_stride_o; + ck_tile::index_t split_stride_lse_acc; + ck_tile::index_t split_stride_o_acc; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; @@ -234,6 +244,176 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) return ck_tile::make_tuple(kargs, grids); } +template +auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.max_seqlen_q, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.max_seqlen_q, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + dim3 grids = + Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits); + + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel argumentszs + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.max_seqlen_q, + args.seqstart_q_ptr, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.batch_stride_lse, + args.split_stride_lse_acc, + args.split_stride_o_acc); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.max_seqlen_q, + args.seqlen_q, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.batch_stride_lse, + args.batch_stride_o, + args.split_stride_lse_acc, + args.split_stride_o_acc); + } + }(); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + + return ck_tile::make_tuple(kargs, grids); +} + // this is used to pattern-match internl kernel implementation, not to instantiate kernel template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); +template +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); + +template +std::string fmha_fwd_splitkv_get_name_(); + +template +struct fmha_fwd_splitkv_combine_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); + +template +std::string fmha_fwd_splitkv_combine_get_name_(); + // This is the public API, will be generated by script struct fmha_fwd_traits { @@ -298,3 +512,4 @@ struct fmha_fwd_traits // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); +float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index e0b4b65595..27347b4476 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -3,1214 +3,62 @@ # generate kernel instances to speed up compilation import argparse -import itertools +from enum import IntEnum from pathlib import Path -from typing import List, Optional, Tuple -from dataclasses import dataclass -import copy -import fnmatch +from typing import List, Optional -DTYPE_MAP = { - "fp16": "ck_tile::fp16_t", - "bf16": "ck_tile::bf16_t", - "fp8" : "ck_tile::fp8_t" +from codegen.cmake_config import * +from codegen.ops import ( + fmha_fwd, + fmha_fwd_splitkv, + fmha_bwd +) + + +class HandlerId(IntEnum): + LIST_BLOBS = 0 + WRITE_BLOBS = 1 + +handlers = { + 'fwd' : (fmha_fwd.list_blobs, fmha_fwd.write_blobs), + 'fwd_splitkv' : (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs), + 'bwd' : (fmha_bwd.list_blobs, fmha_bwd.write_blobs), } -DTYPE_BITS = { - "fp32": 32, - "fp16": 16, - "bf16": 16, - "fp8" : 8, - "bf8" : 8 -} - -MASK_IMPL = { - "generic" : "ck_tile::GenericAttentionMask", - "simplified" : "ck_tile::SimplifiedGenericAttentionMask" -} - -MASK_SIMPLIFIED_MAP = { - "s_no" : "ck_tile::SimplifiedGenericAttentionMask", - "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", -} - -MASK_MAP = { - "no" : "FmhaMasks::NoMask", - "causal" : "FmhaMasks::CausalMask", - "generic" : "FmhaMasks::GenericMask" -} - -BIAS_MAP = { - "no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS", - "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", - "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" -} - -# TODO: this is ugly -BIAS_CHECK_MAP = { - "no" : "bias_enum::no_bias", - "bias" : "bias_enum::elementwise_bias", - "alibi" : "bias_enum::alibi" -} - -MODE_MAP = { - "batch" : "false", - "group" : "true" -} - -LAYOUT_MAP = { - "row" : "true", - "col" : "false" -} - -PIPELINE_MAP = { - "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", - "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", -} - -PIPELINE_ENUM_MAP = { - "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", -} - -BOOL_MAP = { - "t" : "true", - "f" : "false" -} - -TILE_PARTITIONER_MAP = { - "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", - "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", -} - -GEN_DIR = "" # in Cmake, have to generate files in same folder - -FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n -// auto generated by generate.py -#include "fmha_fwd.hpp" -""" - -FMHA_FWD_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; -using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; -using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; - -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape; - -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_bias}, - false, - {F_lse}, - {F_dropout}, - {F_squant}, - {F_occupancy}>; -using fmha_mask_{F_idx} = {F_mask}; - -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, - {F_mode}, - fmha_mask_{F_idx}, - fmha_trait_{F_idx}>; - -using fmha_pipeline_{F_idx} = {F_pipeline}< - fmha_pipeline_problem_{F_idx}>; - -using fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; - -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel<{F_tile_partitioner}, - fmha_pipeline_{F_idx}, - fmha_epilogue_{F_idx}>; - -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - -#include - -template<> -float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) -{{ - using k_ = fmha_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} -""" - -FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" -FMHA_FWD_API=""" -float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ - float r = -1; -{F_dispatch} - return r; -}} -""" - -FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} - }} -""" -FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ -{F_inner_dispatch} - }} -""" -MASK_CHECK_MAP = { - "no" : "t.mask_type == mask_enum::no_mask", - "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", - "generic" : "t.mask_type == mask_enum::window_generic", -} - -MASK_SIMPLIFIED_CHECK_MAP = { - "s_no" : "t.mask_type == mask_enum::no_mask", - "s_mask" : "t.mask_type != mask_enum::no_mask", -} - -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - return fmha_fwd_(s, a); - }} -""" - -def get_mask_map(mask : str): - if mask == "generic": - return MASK_MAP - elif mask == "simplified": - return MASK_SIMPLIFIED_MAP - else: - assert False - return None - -def get_mask_check_map(mask : str): - if mask == "generic": - return MASK_CHECK_MAP - elif mask == "simplified": - return MASK_SIMPLIFIED_CHECK_MAP - else: - assert False - return None - -@dataclass -class FmhaFwdApiTrait: - pipeline_tag : str - # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0blen : int - vlayout : str - mask : str - bias : str # - lse : str # - dropout : str - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str - - @property - def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ - f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' - - @property - def scheck(self) -> str: - if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.spad == 't' : return 'true' # always support - else : return 'true' - elif self.pipeline_tag in ['qr']: - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_q % {self.bm0} == 0' - else: assert False - - @property - def skcheck(self) -> str: - if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qr_fp8']: - if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k % {self.bn0} == 0' - else: assert False - - @property - def dcheck(self) -> str: - if self.pipeline_tag == 'qr_async': - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == 't': return f'a.hdim_q % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr']: - if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {self.bk0blen} == 0' - else: assert False - - @property - def dvcheck(self) -> str: - if self.pipeline_tag == 'qr_async': - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr']: - if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {self.bk0blen} == 0' - else: assert False - -@dataclass -class FmhaFwdPipeline: - tag : str - - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_bias : str # true/false - F_lse : str # - F_dropout : str # - F_squant : str # - F_mask : str # value from MASK_MAP - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f'{self.tag}_v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - if self.F_bias != 'no' : n += f'_{self.F_bias}' - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - if self.F_lse == 't' : n += '_lse' - if self.F_dropout == 't' : n += '_dropout' - if self.F_squant == 't' : n += '_squant' - return n - -class FmhaFwdApiPool: - def __init__(self, mask_impl): - self.pool = dict() - self.mask_impl = mask_impl - - def register_traits(self, trait : FmhaFwdApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() - - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, - F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) - -@dataclass -class FmhaFwdTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along qk gemm unroll - F_bn1 : int # tile size along v head_dim - F_bk1 : int # tile size along kv gemm unroll - F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm : int # number of warps along q seqlen (block warps) - F_rn : int # number of warps along k seqlen(not used) - F_rk : int # number of warps along gemm-k(not used) - F_wm : int # warp size along m (warp size) - F_wn : int # warp size along n - F_wk : int # warp size along k - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - @property - def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\ - f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") - -@dataclass -class FmhaFwdKernel: - direction : str - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize - F_pipeline : FmhaFwdPipeline - mask_impl : str - - def get_tp(self) -> str: - if self.F_mode == 'group': - return 'hbs' - else: - return 'shb' - - @property - def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0blen = self.F_tile.F_bk0blen, - F_rm = self.F_tile.F_rm, - F_rn = self.F_tile.F_rn, - F_rk = self.F_tile.F_rk, - F_wm = self.F_tile.F_wm, - F_wn = self.F_tile.F_wn, - F_wk = self.F_tile.F_wk, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_bias = BIAS_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_occupancy = self.F_tile.F_occupancy, - F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], - F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()]) - - @property - def name(self) -> str: - # TODO: we don't encode idx here - return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name - - @property - def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaFwdApiTrait: - return FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0blen=self.F_tile.F_bk0blen, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - dropout=self.F_pipeline.F_dropout, - squant=self.F_pipeline.F_squant, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad) - -# TODO: design a more practical way to do it -# this is current supported tile size per hdim -def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]: - if direction == 'fwd': - if dtype == 'fp16' or dtype == 'bf16': - return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), - } - elif dtype == 'fp8' or dtype == 'bf8': - return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) - } - else: - return None - else: - return None - -def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # TODO: currently for qr pipeline, let 't' padding to appear later!! - # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' - pipelines = [] - if dtype in ['fp16', 'bf16']: - for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - if hdim == 256: - # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - if receipt == 1: - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - elif dtype in ['fp8', 'bf8']: - # no need lse/dropout kernels - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) - else: - assert False - return pipelines - - gen = list() - api_pool = FmhaFwdApiPool(mask_impl) - - for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()): - d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) - if d == None: - continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] - hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): - if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - k = FmhaFwdKernel(direction=direction, - F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != None: - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if receipt == 2: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] - cond &= pipeline.F_squant == 'f' - if not cond: - continue - api_pool.register_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - -BWD_DQDKDV_PIPELINE_MAP = { - "ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR", - "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS", - "ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR", -} - -BWD_DQDKDV_PIPELINE_ENUM_MAP = { - "ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR", - "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS", - "ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR", -} - -FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n -// auto generated by generate.py -#include "fmha_bwd.hpp" -""" - -FMHA_BWD_DQ_DK_DV_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>; -using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>; -using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; -using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; -using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; - -// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape -// G0&G2 -> GSdP -// G1&G3 -> GdKV -// G4 -> GdQ -using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; - -using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_bias}, - {F_dbias}, - false, - {F_dropout}, - false, - {F_occupancy}>; -using fmha_mask_{F_idx} = {F_mask}; - -using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< - typename FmhaBwdTypeConfig::QDataType, - typename FmhaBwdTypeConfig::KDataType, - typename FmhaBwdTypeConfig::VDataType, - typename FmhaBwdTypeConfig::GemmDataType, - typename FmhaBwdTypeConfig::LSEDataType, - typename FmhaBwdTypeConfig::AccDataType, - typename FmhaBwdTypeConfig::DDataType, - typename FmhaBwdTypeConfig::BiasDataType, - typename FmhaBwdTypeConfig::RandValOutputDataType, - typename FmhaBwdTypeConfig::ODataType, - typename FmhaBwdTypeConfig::OGradDataType, - typename FmhaBwdTypeConfig::QGradDataType, - typename FmhaBwdTypeConfig::KGradDataType, - typename FmhaBwdTypeConfig::VGradDataType, - typename FmhaBwdTypeConfig::BiasGradDataType, - fmha_bwd_shape_{F_idx}, - {F_mode}, - fmha_mask_{F_idx}, - fmha_bwd_trait_{F_idx}>; - -using fmha_bwd_pipeline_{F_idx} = {F_pipeline}< - fmha_bwd_pipeline_problem_{F_idx}>; - -using fmha_bwd_dk_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::AccDataType, - typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, - false, false>>; - -using fmha_bwd_dv_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::AccDataType, - typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, - false, false>>; - -using fmha_bwd_dq_dk_dv_kernel_{F_idx} = - ck_tile::FmhaBwdDQDKDVKernel, - fmha_bwd_pipeline_{F_idx}, - fmha_bwd_dk_epilogue_{F_idx}, - fmha_bwd_dv_epilogue_{F_idx}>; - -using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - -#include - -template<> -float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} - -template<> -void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; - auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); -}} - -template<> -std::string fmha_bwd_dq_dk_dv_get_name_() -{{ - using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; - return k_::GetName(); -}} -""" - -FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" -FMHA_BWD_API=""" -#include - -template -float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - if(s.log_level_ > 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << std::flush; - return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} - ); -}} - -float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ - float r = -1; -{F_dispatch} - return r; -}} -""" - -FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} - }} -""" -FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ -{F_inner_dispatch} - }} -""" - -FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>; - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; - r = fmha_bwd_(s, a); - return r; - }} -""" - -@dataclass -class FmhaBwdDQDKDVApiTrait: - pipeline : str - # sync with fmha_bwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along k seqlen - bhdq : int # q head_dim - bhdv : int # v head_dim - mask : str - bias : str - dbias : str - dropout : str - spad : str - skpad : str - dpad : str - dvpad : str - - @property - def name(self) -> str: - return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' - - def scheck(self, spad1 : str) -> str: - if self.mode == 'group': - return 'true' # always support - elif self.spad == 't' and spad1 == 't': - return f'a.seqlen_q % {self.bm0} != 0' - elif self.spad == 'f' and spad1 == 't': - return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize - else: # self.skpad == 'f' and skpad1 == 'f' - return f'a.seqlen_q % 256 == 0' # BlockSize - - @property - def skcheck(self) -> str: - if self.mode == 'group': - return 'true' # always support - elif self.skpad == 't': - return f'a.seqlen_k % {self.bn0} != 0' - else: - return f'a.seqlen_k % {self.bn0} == 0' - - @property - def dcheck(self) -> str: - if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' - else : return f'a.hdim_q % {self.bhdq} == 0' - - @property - def dvcheck(self) -> str: - if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' - else : return f'a.hdim_v % {self.bhdv} == 0' - -class FmhaBwdApiPool: - def __init__(self, mask_impl): - self.dq_dk_dv_pool = dict() - self.mask_impl = mask_impl - - def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.dq_dk_dv_pool.keys(): - self.dq_dk_dv_pool[trait.dtype] = dict() - if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): - self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() - - self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): - per_hdim_case=str() - for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): - traits=self.dq_dk_dv_pool[dtype][hdim] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - for spad1 in ["t", "f"]: - if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")): - continue - inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout], - F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], - F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad]) - - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - - return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) - -# GEMM0: Q@K=S^T -# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) -# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) -# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk) -# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk) -# Is it necessary to distinguish between K0~K4? -@dataclass -class FmhaBwdDQDKDVTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along gemm0 unroll(F_bhdq) - F_bk1 : int # tile size along gemm1 unroll(F_bm0) - F_bk2 : int # tile size along gemm2 unroll(F_bhdv) - F_bk3 : int # tile size along gemm3 unroll(F_bm0) - F_bk4 : int # tile size along gemm4 unroll(F_bn0) - F_bhdq : int # q head_dim - F_bhdv : int # v head_dim - F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2 - F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2 - F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2 - F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3 - F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3 - F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3 - F_rm2 : int # number of warps along k seqlen (block warps) in gemm4 - F_rn2 : int # number of warps along q seqlen (block warps) in gemm4 - F_rk2 : int # number of warps along gemm-k (not used) in gemm4 - F_wm : int # warp size along m (warp size) - F_wn : int # warp size along n - F_wk : int # warp size along k - F_occupancy : int # occupancy - @property - def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ - f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ - f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}" - -@dataclass -class FmhaBwdDQDKDVKernel: - direction : str - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_tile : FmhaBwdDQDKDVTileSize - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_bias : str # - F_dbias : str # - F_dropout : str # - F_mask : str # value from MASK_MAP - F_mode : str # value from MODE_MAP - F_pipeline : str - mask_impl : str - - @property - def template(self) -> str: - return FMHA_BWD_KERNEL_HEADER + \ - FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bk1 = self.F_tile.F_bk1, - F_bk2 = self.F_tile.F_bk2, - F_bk3 = self.F_tile.F_bk3, - F_bk4 = self.F_tile.F_bk4, - F_bhdq = self.F_tile.F_bhdq, - F_bhdv = self.F_tile.F_bhdv, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_rm2 = self.F_tile.F_rm2, - F_rn2 = self.F_tile.F_rn2, - F_rk2 = self.F_tile.F_rk2, - F_wm = self.F_tile.F_wm, - F_wn = self.F_tile.F_wn, - F_wk = self.F_tile.F_wk, - F_spad = BOOL_MAP[self.F_spad], - F_skpad = BOOL_MAP[self.F_skpad], - F_dpad = BOOL_MAP[self.F_dpad], - F_dvpad = BOOL_MAP[self.F_dvpad], - F_bias = BIAS_MAP[self.F_bias], - F_dbias = BOOL_MAP[self.F_dbias], - F_dropout = BOOL_MAP[self.F_dropout], - F_occupancy = self.F_tile.F_occupancy, - F_mask = get_mask_map(self.mask_impl)[self.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline], - F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline]) - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name - if pn != '' : n += f'_{pn}' - if self.F_bias != 'no' : n += f'_{self.F_bias}' - if self.F_dbias == 't' : n += '_dbias' - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - if self.F_dropout == 't' : n += '_dropout' - return n - - @property - def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaBwdDQDKDVApiTrait: - return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bhdq=self.F_tile.F_bhdq, - bhdv=self.F_tile.F_bhdv, - mask=self.F_mask, - bias=self.F_bias, - dbias=self.F_dbias, - dropout=self.F_dropout, - spad=self.F_spad, - skpad=self.F_skpad, - dpad=self.F_dpad, - dvpad=self.F_dvpad) - -# TODO: design a more practical way to do it -# this is current supported tile size & pipeline. -def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]: - if direction == 'bwd': - if dtype == 'fp16' or dtype == 'bf16': - return { - '32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1), - "qs_ks_vr_dos"], - '64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), - "qs_ks_vr_dos"], - '128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), - "ks_vr"] - } - else: - return None - else: - return None - -def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]: - # TODO: we don't support tuning yet, so pick up one value for pad - # support this in future - gen = list() - api_pool = FmhaBwdApiPool(mask_impl) - - for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype) - if d == None: - continue - for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): - tile = d[hdim_str][0] - ppl = d[hdim_str][1] - hdim = int(hdim_str) - if (mode == "group") and (spad == "f" or skpad == "f"): - continue - if ((bias == "no" or bias == "alibi") and dbias == "t"): - continue - k = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, - F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, - F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, - F_pipeline=ppl, mask_impl=mask_impl) - if kernel_filter != None: - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if receipt == 2: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'alibi'] - if not cond: - continue - api_pool.register_dq_dk_dv_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - -FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, - {F_dvpad}, - {F_occupancy}>; - -using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< - typename FmhaBwdTypeConfig::ODataType, - typename FmhaBwdTypeConfig::OGradDataType, - typename FmhaBwdTypeConfig::DDataType, - /* BlockSize = */ 256, - {F_hdim}, - {F_mode}, - fmha_bwd_dot_do_o_trait_{F_idx}>; - -using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO< - fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>; - -using fmha_bwd_dot_do_o_kernel_{F_idx} = - ck_tile::FmhaBwdOGradDotOKernel, - fmha_bwd_dot_do_o_{F_idx}>; - -using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; - -#include - -template<> -float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} - -template<> -void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; - auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); -}} - -template<> -std::string fmha_bwd_dot_do_o_get_name_() -{{ - using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; - return k_::GetName(); -}} -""" - -@dataclass -class FmhaBwdOGradDotOKernel: - direction : str - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_spad : str # true/false - F_dvpad : str # - F_mode : str # value from MODE_MAP - F_occupancy : int - - @property - def template(self) -> str: - return FMHA_BWD_KERNEL_HEADER + \ - FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], - F_spad = BOOL_MAP[self.F_spad], - F_dvpad = BOOL_MAP[self.F_dvpad], - F_mode = MODE_MAP[self.F_mode], - F_occupancy = self.F_occupancy) - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" - if pn != '' : n += f'_{pn}' - return n - - @property - def filename(self) -> str: - return self.name + ".cpp" - -def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: - # TODO: we don't support tuning yet, so pick up one value for pad/occupancy - # support this in future - def get_occupancy(dtype, hdim): - return 2 - - gen = list() - - for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype) - if d == None: - continue - for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]): - hdim = int(hdim_str) - if (mode == "group" and spad == "f"): - continue - k = FmhaBwdOGradDotOKernel(direction=direction+"_dot_do_o", F_idx=0, F_hdim=hdim, F_dtype=dtype, - F_spad=spad, F_dvpad=dvpad, F_mode=mode, - F_occupancy=get_occupancy(dtype, hdim)) - gen.append(k) - - return gen - -def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) - -def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) - -def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: if output_dir is None: output_dir = Path(__file__).parent else: output_dir = Path(output_dir) / GEN_DIR output_dir.mkdir(parents=True, exist_ok=True) - if direction == 'fwd': - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - write_single_fwd_kernel(kernel, output_dir) - write_fwd_api(api_pool, output_dir) - else: - kernels = get_bwd_dot_do_o_blobs() - for kernel in kernels: - write_single_bwd_dot_do_o_kernel(kernel, output_dir) - api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) - write_bwd_api(api_pool, output_dir) + + for api in api_list: + handler = handlers[api][HandlerId.WRITE_BLOBS] + handler(output_dir, kernel_filter, receipt, mask_impl) # list all the files that will be generated -def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) - with file_path.open('a') as f: - if direction == 'fwd': - _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") - else: - kernels = get_bwd_dot_do_o_blobs() - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") + + for api in api_list: + handler = handlers[api][HandlerId.LIST_BLOBS] + handler(file_path, kernel_filter, receipt, mask_impl) if __name__ == "__main__": parser = argparse.ArgumentParser( prog="generate", - description="gen api for CK fmha kernel", + description="gen API for CK fmha kernel", ) parser.add_argument( "-d", - "--direction", + "--direction", # we keep 'direction' option for backward compatibility + "-a", + "--api", default='fwd', - choices=['fwd', 'bwd'], required=False, - help="choose the direction of kernels(default: fwd)" + help="supply API(s) to generate (default: fwd). separated by comma." ) parser.add_argument( "-o", @@ -1251,7 +99,8 @@ if __name__ == "__main__": ) args = parser.parse_args() + api_list = args.direction.split(',') if args.list_blobs is not None: - list_blobs(args.list_blobs, args.direction, args.filter, int(args.receipt), mask_impl=args.mask) + list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask) else: - write_blobs(args.output_dir, args.direction, args.filter, int(args.receipt), mask_impl=args.mask) + write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask) \ No newline at end of file diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 5684868306..057d2b11ff 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -10,6 +10,10 @@ #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" @@ -22,6 +26,12 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index ce8493663f..c022edf723 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -299,6 +299,23 @@ struct SimplifiedGenericAttentionMask } } + template + CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, + number height, + number width, + index_t num_splits, + index_t i_split) const + { + auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); + + const index_t x_per_split = ck_tile::max(1, x_total / num_splits); + const index_t split_start = x_per_split * i_split; + const index_t split_end = (i_split == num_splits - 1 ? x_total : split_start + x_per_split); + + return ck_tile::make_tuple(ck_tile::max(origin_start, split_start), + ck_tile::min(origin_end, split_end)); + } + // to get the loop length along Y axis, return index:[start, end), end-start=length // use this if need loop over Y axis tile by tile (like q-seqlen loopover) // TODO: y_end still could be negative, so end-start could be negative(need check) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp new file mode 100644 index 0000000000..6f4313d5b6 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -0,0 +1,455 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +template +struct FmhaFwdSplitKVCombineKernel +{ + using TilePartitioner = remove_cvref_t; + using FmhaPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + + using LSEDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + __host__ static std::string GetName() + { + // sync with generate.py + // clang-format off + + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + "b" + _TS_(FmhaPipeline::kM0) + "x" + + _TS_(FmhaPipeline::kN1) + "_" + + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + + _SS_(FmhaPipeline::name) + + (pn.empty() ? "" : "_" + pn) + + (kStoreLSE ? "_lse" : "" ) + + (kDoFp8StaticQuant ? "_squant" : "" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct EmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct CommonKargs + { + const void* lse_acc_ptr; + const void* o_acc_ptr; + void* o_ptr; + + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + + ck_tile::index_t seqlen_q; + ck_tile::index_t hdim_v; + ck_tile::index_t num_splits; + + ck_tile::index_t row_stride_o_acc; + ck_tile::index_t row_stride_o; + + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; + ck_tile::index_t nhead_stride_o; + + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; + + ck_tile::index_t split_stride_lse_acc; + ck_tile::index_t split_stride_o_acc; + }; + + struct CommonLSEKargs + { + void* lse_ptr = nullptr; + ck_tile::index_t nhead_stride_lse = 0; + ck_tile::index_t batch_stride_lse = 0; + }; + + struct Fp8StaticQuantKargs + { + float scale_o; + }; + + struct BatchModeKargs + : CommonKargs, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_o; + }; + + struct GroupModeKargs + : CommonKargs, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + }; + + using Kargs = std::conditional_t; + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* lse_acc_ptr, + const void* o_acc_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t batch, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits, + float scale_o, + ck_tile::index_t row_stride_o_acc, + ck_tile::index_t row_stride_o, + ck_tile::index_t nhead_stride_lse_acc, + ck_tile::index_t nhead_stride_o_acc, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_lse_acc, + ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t split_stride_lse_acc, + ck_tile::index_t split_stride_o_acc) + { + Kargs kargs{{lse_acc_ptr, + o_acc_ptr, + o_ptr, + batch, + max_seqlen_q, + seqlen_q, + hdim_v, + num_splits, + row_stride_o_acc, + row_stride_o, + nhead_stride_lse_acc, + nhead_stride_o_acc, + nhead_stride_o, + batch_stride_lse_acc, + batch_stride_o_acc, + split_stride_lse_acc, + split_stride_o_acc}, // args for common karg + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + batch_stride_o}; + + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_o = scale_o; + } + + return kargs; + } + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* lse_acc_ptr, + const void* o_acc_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t batch, + ck_tile::index_t max_seqlen_q, + const void* seqstart_q_ptr, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits, + float scale_o, + ck_tile::index_t row_stride_o_acc, + ck_tile::index_t row_stride_o, + ck_tile::index_t nhead_stride_lse_acc, + ck_tile::index_t nhead_stride_o_acc, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_lse_acc, + ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t split_stride_lse_acc, + ck_tile::index_t split_stride_o_acc) + { + Kargs kargs{{lse_acc_ptr, + o_acc_ptr, + o_ptr, + batch, + max_seqlen_q, + -1, // seqlen will be updated by another pointer + hdim_v, + num_splits, + row_stride_o_acc, + row_stride_o, + nhead_stride_lse_acc, + nhead_stride_o_acc, + nhead_stride_o, + batch_stride_lse_acc, + batch_stride_o_acc, + split_stride_lse_acc, + split_stride_o_acc}, // args for common karg + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + reinterpret_cast(seqstart_q_ptr)}; + + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_o = scale_o; + } + + return kargs; + } + + __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); + } + + __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = + TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + const long_index_t batch_offset_lse_acc = + static_cast(i_batch) * kargs.batch_stride_lse_acc; + const long_index_t batch_offset_o_acc = + static_cast(i_batch) * kargs.batch_stride_o_acc; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + + batch_offset_o = query_start * kargs.row_stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + } + else + { + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } + + // for simplicity, batch stride we just modify the pointer + const LSEDataType* lse_acc_ptr = + reinterpret_cast(kargs.lse_acc_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_lse_acc + batch_offset_lse_acc; + const OaccDataType* o_acc_ptr = + reinterpret_cast(kargs.o_acc_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o_acc + batch_offset_o_acc; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // LSEacc/Oacc DRAM and DRAM windows + const auto lse_acc_dram = [&]() { + const auto lse_acc_dram_naive = make_naive_tensor_view( + lse_acc_ptr, + make_tuple(kargs.num_splits, kargs.seqlen_q), + make_tuple(kargs.split_stride_lse_acc, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + lse_acc_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_acc_dram = [&]() { + const auto o_acc_dram_naive = make_naive_tensor_view( + o_acc_ptr, + make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v), + make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1), + number{}, + number<1>{}); + + auto o_acc_dram_view = pad_tensor_view( + o_acc_dram_naive, + make_tuple(number<1>{}, number{}, number{}), + sequence{}); + + const index_t padded_max_seqlen_q = + o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}]; + const index_t padded_hdim_v = + o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}]; + + return transform_tensor_view( + o_acc_dram_view, + make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)), + make_pass_through_transform(padded_hdim_v)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + }(); + + auto lse_acc_dram_window = make_tile_window( + lse_acc_dram, + [&]() { + return make_tuple(number{}, number{}); + }(), + {0, i_m0}); + + auto o_acc_dram_window = make_tile_window( + o_acc_dram, + [&]() { + return make_tuple(number{}, number{}); + }(), + {i_m0, i_n1}); + + // LSE DRAM window + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + auto o_acc_tile = [&]() { + if constexpr(kDoFp8StaticQuant) + { + return FmhaPipeline{}( + lse_acc_dram_window, + o_acc_dram_window, + lse_dram_window, + identity{}, // lse_element_func + composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func + kargs.num_splits, + kargs.max_seqlen_q, + smem_ptr); + } + else + { + return FmhaPipeline{}(lse_acc_dram_window, + o_acc_dram_window, + lse_dram_window, + kargs.num_splits, + kargs.max_seqlen_q, + smem_ptr); + } + }(); + + // O DRAM and DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.row_stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp new file mode 100644 index 0000000000..9f04843a39 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FmhaFwdSplitKVCombineTilePartitioner +{ + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN1 = kN1_; + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * + ck_tile::integer_divide_ceil(hdim_v_, kN1), + nhead_, + batch_size_); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp new file mode 100644 index 0000000000..45ed185ada --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -0,0 +1,913 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] + +namespace ck_tile { + +template +struct FmhaFwdSplitKVKernel +{ + using TilePartitioner = ck_tile::remove_cvref_t; + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + using OaccDataType = remove_cvref_t; + + using VLayout = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; + static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + __host__ static std::string GetName() + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using gbr = typename bfs::Gemm0BlockWarps; + using gwt = typename bfs::Gemm0WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct EmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct CommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* lse_acc_ptr; + void* o_acc_ptr; + + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + ck_tile::index_t num_splits; + + float scale_s; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o_acc; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; + + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; + + ck_tile::index_t split_stride_lse_acc; + ck_tile::index_t split_stride_o_acc; + }; + + struct CommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct BatchModeBiasKargs : CommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct AlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + + struct MaskKargs + { + // ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct Fp8StaticQuantKargs + { + float scale_p; + }; + + struct CommonDropoutKargs + { + void init_dropout(const float p_drop, + const std::tuple& drop_seed_offset) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + drop_seed = std::get<0>(drop_seed_offset); + drop_offset = std::get<1>(drop_seed_offset); + } + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + uint64_t drop_seed = 1; + uint64_t drop_offset = 0; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + struct BatchModeDropoutKargs : CommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; + }; + + struct BatchModeKargs + : CommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + }; + + struct GroupModeKargs + : CommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; + }; + + using Kargs = std::conditional_t; + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_acc_ptr, + void* o_acc_ptr, + ck_tile::index_t batch, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + ck_tile::index_t num_splits, + float scale_s, + float scale_p, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o_acc, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse_acc, + ck_tile::index_t nhead_stride_o_acc, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse_acc, + ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t split_stride_lse_acc, + ck_tile::index_t split_stride_o_acc, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lse_acc_ptr, + o_acc_ptr, + batch, + max_seqlen_q, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + num_splits, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o_acc, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_lse_acc, + nhead_stride_o_acc, + batch_stride_lse_acc, + batch_stride_o_acc, + split_stride_lse_acc, + split_stride_o_acc}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout + batch_stride_q, + batch_stride_k, + batch_stride_v}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_acc_ptr, + void* o_acc_ptr, + ck_tile::index_t batch, + ck_tile::index_t max_seqlen_q, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + ck_tile::index_t num_splits, + float scale_s, + float scale_p, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o_acc, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse_acc, + ck_tile::index_t nhead_stride_o_acc, + ck_tile::index_t batch_stride_lse_acc, + ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t split_stride_lse_acc, + ck_tile::index_t split_stride_o_acc, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lse_acc_ptr, + o_acc_ptr, + batch, + max_seqlen_q, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + num_splits, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o_acc, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_lse_acc, + nhead_stride_o_acc, + batch_stride_lse_acc, + batch_stride_o_acc, + split_stride_lse_acc, + split_stride_o_acc}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr)}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + + __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits) + { + return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, hdim_v, num_splits); + } + + __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = + TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v, kargs.num_splits); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + const long_index_t batch_offset_lse_acc = + static_cast(i_batch) * kargs.batch_stride_lse_acc; + const long_index_t batch_offset_o_acc = + static_cast(i_batch) * kargs.batch_stride_o_acc; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + { + batch_offset_v = key_start; + } + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias + key_start; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + OaccDataType* o_acc_ptr = reinterpret_cast(kargs.o_acc_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o_acc + + batch_offset_o_acc + i_split * kargs.split_stride_o_acc; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.seqlen_k), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove + /// following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // lse acc + auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() { + constexpr auto lse_acc_dram_window_lengths = make_tuple(number{}); + LSEDataType* lse_acc_ptr = + reinterpret_cast(kargs.lse_acc_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse_acc + + batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc; + + const auto lse_acc_dram = [&]() { + const auto lse_acc_dram_naive = + make_naive_tensor_view(lse_acc_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0}); + }(); + + // dropout + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + uint64_t drop_seed = 0; + uint64_t drop_offset = 0; + bool is_store_randval = false; + + if constexpr(kHasDropout) + { + rp_undrop = kargs.rp_undrop; + p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; + drop_seed = kargs.drop_seed; + drop_offset = kargs.drop_offset; + is_store_randval = kargs.is_store_randval; + } + BlockDropout dropout(i_batch, + i_nhead, + kargs.num_head_q, + drop_seed, + drop_offset, + rp_undrop, + p_undrop_in_uint8_t, + is_store_randval); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? + SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + auto o_acc_tile = [&, i_split_ = i_split]() { + if constexpr(kDoFp8StaticQuant) + { + return FmhaPipeline{}(q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_acc_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{kargs.scale_p}, // p_compute_element_func + identity{}, // o_acc_element_func + kargs.num_splits, + i_split_, + mask, + position_encoding, + kargs.scale_s, + smem_ptr, + dropout); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_acc_dram_window, + kargs.num_splits, + i_split_, + mask, + position_encoding, + kargs.scale_s, + smem_ptr, + dropout); + } + }(); + + // Oacc DRAM and Oacc DRAM window + auto o_acc_dram = [&]() { + const auto o_acc_dram_naive = make_naive_tensor_view( + o_acc_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.hdim_v, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_acc_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_acc_dram_window = + make_tile_window(o_acc_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_acc_dram_window, o_acc_tile); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp new file mode 100644 index 0000000000..aec37cb36f --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FmhaFwdSplitKVTilePartitioner +{ + using BlockFmhaShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; + static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; + static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0; + static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; + static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; + + __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) * + ck_tile::integer_divide_ceil(hdim_v, kN1), + nhead * num_splits, + batch_size); + } + + CK_TILE_DEVICE auto + operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v, ck_tile::index_t num_splits) + { + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(blockIdx.x, num_tile_n1); + const auto [i_nhead, i_split] = f(blockIdx.y, num_splits); + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp new file mode 100644 index 0000000000..7efdb798cb --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { +namespace detail { +template +struct log2; + +template <> +struct log2<16> : std::integral_constant +{ +}; + +template <> +struct log2<32> : std::integral_constant +{ +}; + +template <> +struct log2<64> : std::integral_constant +{ +}; + +template <> +struct log2<128> : std::integral_constant +{ +}; +} // namespace detail + +template +struct BlockFmhaFwdSplitKVCombinePipeline +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using LSEDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kHeadDimV = Problem::kHeadDimV; + static constexpr index_t kM0 = Problem::kM0; + static constexpr index_t kN1 = Problem::kN1; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr index_t kMaxSplits = Problem::kMaxSplits; + + static constexpr index_t kAlignmentLSE = + kPadSeqLenQ ? 1 : Policy::template GetAlignmentLSE(); + static constexpr index_t kAlignmentLSEacc = kAlignmentLSE; + + static constexpr index_t kAlignmentOacc = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kHeadDimV <= 32) + { + constexpr std::array occupancy{3, 3, 3, 1}; + return occupancy[detail::log2::value - 4]; + } + else if constexpr(kHeadDimV <= 128) + { + constexpr std::array occupancy{3, 3, 2, 1}; + return occupancy[detail::log2::value - 4]; + } + else if constexpr(kHeadDimV <= 256) + { + constexpr std::array occupancy{2, 2, 2, 1}; + return occupancy[detail::log2::value - 4]; + } + } + }(); + + static constexpr const char* name = "unused"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, + const OaccDramBlockWindowTmp& o_acc_dram_block_window_tmp, + LSEDramBlockWindowTmp& lse_dram_window_tmp, + const LSEElementFunction& lse_element_func, + const OaccElementFunction& o_acc_element_func, + index_t num_splits, + index_t max_seqlen_q, + void* smem_ptr) const + { + // lse_acc tile in LDS + LSEDataType* lse_acc_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + auto lse_acc_lds = [=, lds_desc = Policy::template MakeLSEaccLdsBlockDescriptor()]( + index_t row, index_t col) -> LSEDataType& { + return lse_acc_lds_ptr[lds_desc.calculate_offset(make_tuple(row, col))]; + }; + + auto lse_acc_lds_write_window = [&]() { + auto view = make_tensor_view( + lse_acc_lds_ptr, Policy::template MakeLSEaccLdsStoreBlockDescriptor()); + return make_tile_window(view, make_tuple(number{}, number{}), {0, 0}); + }(); + + auto lse_acc_dram_window = + make_tile_window(lse_acc_dram_block_window_tmp.get_bottom_tensor_view(), + lse_acc_dram_block_window_tmp.get_window_lengths(), + lse_acc_dram_block_window_tmp.get_window_origin(), + Policy::template MakeLSEaccDramTileDistribution()); + + // copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]). + auto lse_acc_tile = load_tile(lse_acc_dram_window); + store_tile(lse_acc_lds_write_window, lse_acc_tile); + block_sync_lds(); + + auto lse_accum = make_static_distributed_tensor( + Policy::template MakeLSEaccRegTileDistribution()); + + // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)]) + // this will extend the distributed tensor width so that each thread in wave have data to + // reduce. + { + constexpr auto spans = decltype(lse_accum)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices( + lse_accum.get_tile_distribution(), i_j_idx); + + const auto col = x_indices.at(number<1>{}); + if(col < num_splits) + { + const auto row = x_indices.at(number<0>{}); + + lse_accum(i_j_idx) = lse_acc_lds(row, col); + } + else + { + lse_accum(i_j_idx) = -numeric::infinity(); + } + }); + }); + } + + // compute the logsumexp of the LSE along the split dimension. + const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + auto lse_max = block_tile_reduce( + lse_accum, sequence<1>{}, f_max, -numeric::infinity()); + block_tile_reduce_sync(lse_max, f_max, bool_constant{}); + + static const auto get_validated_m = [](LSEDataType raw_m) { + return raw_m == -numeric::infinity() ? type_convert(0.f) + : raw_m; + }; + + decltype(lse_accum) lse_exp; + { + constexpr auto spans = decltype(lse_exp)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + lse_exp(i_j_idx) = + ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx))); + }); + }); + } + + auto lse_sum = block_tile_reduce( + lse_exp, sequence<1>{}, f_sum, type_convert(0)); + block_tile_reduce_sync(lse_sum, f_sum, bool_constant{}); + + decltype(lse_max) lse_logsum; + { + constexpr auto spans = decltype(lse_logsum)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx)) + { + lse_logsum(i_idx) = numeric::infinity(); + } + else + { + lse_logsum(i_idx) = + ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx)); + } + }); + } + + // store the lse scales in shared memory. + { + constexpr auto spans = decltype(lse_accum)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + const auto x_indices = get_x_indices_from_distributed_indices( + lse_accum.get_tile_distribution(), i_j_idx); + + const auto col = x_indices.at(number<1>{}); + if(col < num_splits) + { + const auto row = x_indices.at(number<0>{}); + + lse_acc_lds(row, col) = + ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx)); + } + }); + }); + } + block_sync_lds(); + + if constexpr(kStoreLSE) + { + constexpr auto spans = decltype(lse_logsum)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(lse_logsum(i_idx) == numeric::infinity()) + { + lse_logsum(i_idx) = -numeric::infinity(); + } + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum)); + } + + auto o_acc_dist = Policy::template MakeOaccDramTileDistribution(); + auto o_acc_dram_window = + make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(), + o_acc_dram_block_window_tmp.get_window_lengths(), + o_acc_dram_block_window_tmp.get_window_origin(), + o_acc_dist); + auto o_acc = make_static_distributed_tensor(o_acc_dist); + clear_tile(o_acc); + + const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0; + + for(index_t i_split = 0; i_split < num_splits; ++i_split) + { + auto o_tile = load_tile(o_acc_dram_window); + { + constexpr auto spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices( + o_acc.get_tile_distribution(), i_j_idx); + + const auto row = x_indices.at(number<0>{}); + + const LSEDataType lse_scale = lse_acc_lds(row, i_split); + o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx); + }); + }); + } + + move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0}); + } + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window, + const OaccDramBlockWindow& o_acc_dram_block_window, + LSEDramBlockWindow& lse_dram_block_window, + index_t num_splits, + index_t max_seqlen_q, + void* smem_ptr) const + { + return operator()(lse_acc_dram_block_window, + o_acc_dram_block_window, + lse_dram_block_window, + identity{}, + identity{}, + num_splits, + max_seqlen_q, + smem_ptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp new file mode 100644 index 0000000000..2eb092f055 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE() + { + using LSEDataType = remove_cvref_t; + return 16 / sizeof(LSEDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc() + { + using OaccDataType = remove_cvref_t; + return 16 / sizeof(OaccDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using ODataType = remove_cvref_t; + return 16 / sizeof(ODataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return sizeof(typename Problem::LSEDataType) * + MakeLSEaccLdsBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution() + { + using LSEDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::kM0; + constexpr index_t kMPerBlock = Problem::kMaxSplits; + + constexpr index_t NPerThread = 16 / sizeof(LSEDataType); + constexpr index_t NThreads = kNPerBlock / NPerThread; + + constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads; + constexpr index_t TotalWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (TotalWarps * MThreadsPerWarp); + + static_assert(NThreads * NPerThread == kNPerBlock); + static_assert(MPerThread * TotalWarps * MThreadsPerWarp == kMPerBlock); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + // 3d + padding, [kMaxSplits, kM0] + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsStoreBlockDescriptor() + { + using LSEDataType = remove_cvref_t; + + constexpr index_t kMPerBlock = Problem::kMaxSplits; + constexpr index_t kNPerBlock = Problem::kM0; + constexpr index_t NPack = 16 / sizeof(LSEDataType); + + constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * NPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor( + lse_acc_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lse_acc_lds_block_desc; + } + + // 3d + padding, [kM0, kMaxSplits] + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor() + { + using LSEDataType = remove_cvref_t; + + constexpr index_t kMPerBlock = Problem::kMaxSplits; + constexpr index_t kNPerBlock = Problem::kM0; + constexpr index_t NPack = 16 / sizeof(LSEDataType); + + constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * NPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor( + lse_acc_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return lse_acc_t_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size()); + constexpr index_t kMPerBlock = Problem::kM0; + + constexpr index_t NThreads = get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / NThreads; + + constexpr index_t MThreads = kBlockSize / NThreads; + constexpr index_t MPerThread = kMPerBlock / MThreads; + + static_assert(NThreads * NPerThread == kNPerBlock); + static_assert(MThreads * MPerThread == kMPerBlock); + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<2>>, + tuple, sequence<0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution() + { + using OaccDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kNPerBlock = Problem::kN1; + + constexpr index_t N1 = 16 / sizeof(OaccDataType); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t M2 = get_warp_size() / N0; + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp new file mode 100644 index 0000000000..a6d74b3885 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -0,0 +1,666 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaFwdSplitKVPipelineQRKSVS +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = true; // always store LSE (acc) + static constexpr bool kHasDropout = false; // ignore this flag + static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kK0BlockLength <= 32) + { + return 2; + } + else if constexpr(kK0BlockLength <= 64) + { + return 3; + } + else if constexpr(kK0BlockLength <= 128) + { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kK0BlockLength <= 256) + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile + const LSEaccElementFunction& lse_acc_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + index_t num_splits, + index_t i_split, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramTileDistribution()); + + auto q = load_tile(q_dram_window); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking || kHasUnevenSplits) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse_acc = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse_acc, -numeric::infinity()); + + store_tile(lse_acc_dram_window_tmp, + tile_elementwise_in(lse_acc_element_func, lse_acc)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + auto q_tile = tile_elementwise_in(q_element_func, q); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); + do + { + // STAGE 1, QK gemm + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + auto k_block_tile = load_tile(k_dram_window); + { + move_tile_window(k_dram_window, {0, kK0}); + clear_tile(s_acc); // initialize C + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + k_block_tile = load_tile(k_dram_window); + } + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + move_tile_window(k_dram_window, {0, kK0}); + + store_tile( + k_lds_window, + tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + }); + } + + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + block_sync_lds(); + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_lds_window); + } + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + + /// TODO: only check in last iteration without increasing code size + if constexpr(kHasUnevenSplits) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + set_tile_if(s_acc, + -numeric::infinity(), + [&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) { + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return seqlen_k_end_ <= col; + }); + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + if constexpr(kHasDropout) + { + dropout.Run( + smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); + } + + block_sync_lds(); + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_prefetch); + store_tile( + v_lds_window, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch + } + move_tile_window(v_dram_window, {0, kK1}); + + const auto p = + cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v); + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v)); // store next v + } + move_tile_window(v_dram_window, {0, kK1}); + }); + } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + // tail + { + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + } + } while(++i_total_loops < num_total_loop); + + if constexpr(kStoreLSE) + { + // store lse acc + auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); + sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } +#else + lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); +#endif + }); + + store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile + index_t num_splits, + index_t i_split, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + BlockDropout& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_acc_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + num_splits, + i_split, + mask, + position_encoding, + scale_s, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp new file mode 100644 index 0000000000..ae363a4978 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp @@ -0,0 +1,770 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) +template +struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) + // only need special care about seq_k padding (oob need set -INF of p instead of zero) + static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && + Problem::kPadHeadDimV == true); + static constexpr bool kPadSeqLenQ = true; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = true; // always store LSE (acc) + static constexpr bool kHasDropout = false; // ignore this flag + static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static constexpr auto R_LOG2E = 1.0 / log2e_v; +#endif + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kK0BlockLength <= 32) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && + FmhaMask::IsMasking) + return 1; + else + return 2; + } + else if constexpr(kK0BlockLength <= 64) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 2; + else + return 3; + } + else if constexpr(kK0BlockLength <= 128) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kK0BlockLength <= 256) + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr_async"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& /*k_element_func*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile + const LSEaccElementFunction& lse_acc_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + index_t num_splits, + index_t i_split, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + + // K tile in LDS + auto k_lds_ptr = reinterpret_cast(smem_ptr); + auto k_lds_store = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), + Policy::template MakeKLdsStoreBlockDescriptor(i_buf).get_lengths(), + {0, 0, 0}); + }, + number{}); + +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + auto k_lds_load = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor(i_buf)), + Policy::template MakeKLdsLoadBlockDescriptor(i_buf).get_lengths(), + {0, 0}); + }, + number{}); +#else + auto k_lds_Load_view = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); + + auto k_lds_load = + make_tile_window(k_lds_Load_view, + Policy::template MakeKLdsLoadBlockDescriptor().get_lengths(), + {0, 0}); +#endif + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramTileDistribution()); + + // TODO: we use async Copy for K, which is inline asm + // a side effect is we have to use inline asm for q as well + auto q = decltype(load_tile(q_dram_window)){}; + set_tile(q, number<0>{}); // use per-dword clear to avoid scratch + load_tile_raw(q, q_dram_window); + __builtin_amdgcn_sched_barrier(0); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse_acc = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse_acc, -numeric::infinity()); + + store_tile(lse_acc_dram_window_tmp, + tile_elementwise_in(lse_acc_element_func, lse_acc)); + } + buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) + // otherwise will have compute error(maybe compiler bug?) + + // Note: here occ are all cleard, return it + return o_acc; + } + __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + // prefetch K tile + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + __builtin_amdgcn_sched_barrier(0); + + buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer()); + (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32 + // auto q_tile = q; // tile_elementwise_in(q_element_func, q); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + // main loop + do + { + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + async_load_tile_raw(k_lds_store(number{})>{}), + k_dram_window); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + async_load_fence(k_dram_window.get_num_access()); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + gemm_0(s_acc, + get_slice_tile( + q, sequence<0, i_k0 * kK0>{}, sequence{}), +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[number{})>{}]); + +#else + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); +#endif + }); + } + + // TODO: this to fix a bug when loop smaller than 2, + // the following fence/barrier will be scheduled inside 1st loop + if constexpr(k0_loops <= 2) + __builtin_amdgcn_sched_barrier(0); + + async_load_fence(); + __builtin_amdgcn_s_barrier(); + + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + auto v_buf = load_tile(v_dram_window, bool_constant{}); + __builtin_amdgcn_sched_barrier(0); + { // tail + gemm_0(s_acc, + get_slice_tile( + q, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[number{})>{}]); + +#else + get_slice_tile( + k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); +#endif + } + __builtin_amdgcn_sched_barrier(1); + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + + /// TODO: only check in last iteration without increasing code size + if constexpr(kHasUnevenSplits) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + set_tile_if(s_acc, + -numeric::infinity(), + [&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) { + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return seqlen_k_end_ <= col; + }); + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // store & prefetch next v, after the max reduction + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile(v_dram_window, bool_constant{}); // load next v_buf + } + __builtin_amdgcn_sched_barrier(0); + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration. alibi does not have this problem + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + if constexpr(kHasDropout) + { + auto randval_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + dropout.Run( + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); + } + + const auto p = + cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) + { + v_buf = load_tile(v_dram_window, bool_constant{}); // load next v_buf + } + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store next v_buf + } + if constexpr(i_k1 < k1_loops - 1) + move_tile_window(v_dram_window, {0, kK1}); + }); + } + i_total_loops++; + if(i_total_loops < num_total_loop) + { + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window = + make_tile_window(k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); + + if constexpr(k1_loops >= 2 && + LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + } + // tail + { + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + } + } while(i_total_loops < num_total_loop); + + // store lse acc + if constexpr(kStoreLSE) + { + auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); + sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse_acc(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); + } + else + { + lse_acc(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); + } +#else + lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); +#endif + }); + + store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile + index_t num_splits, + index_t i_split, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + BlockDropout& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_acc_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + num_splits, + i_split, + mask, + position_encoding, + scale_s, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp new file mode 100644 index 0000000000..6109fa5ab9 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +using BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp new file mode 100644 index 0000000000..338319ab32 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +using BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 1b72b60054..23b75f16ac 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -54,4 +54,69 @@ struct BlockFmhaPipelineProblem static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; +template +struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem +{ + static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; +}; + +template +struct BlockFmhaSplitKVCombinePipelineProblem +{ + using LSEDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kBlockSize = 256; + static constexpr bool kIsGroupMode = kIsGroupMode_; + + static constexpr index_t kHeadDimV = HeadDimV_; + static constexpr index_t kM0 = kM0_; + static constexpr index_t kN1 = kN1_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; + static constexpr index_t kMaxSplits = Traits::kMaxSplits; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 973ffa9f82..a59431e39d 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -32,6 +32,50 @@ struct TileFmhaTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; }; +template +struct TileFmhaFwdSplitKVTraits : TileFmhaTraits +{ + // determine if some split (length) is not divisible by tile size + static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; +}; + +template +struct TileFmhaFwdSplitKVCombineTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + + static constexpr index_t kMaxSplits = (1 << kLogMaxSplits_); + static_assert(kMaxSplits <= get_warp_size() || kMaxSplits % get_warp_size() == 0); + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + template From a32b1bc64794b8bd51f6924cb7d72dd1059803dd Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Wed, 26 Jun 2024 22:04:52 +0800 Subject: [PATCH 04/16] Replace hipDeviceSynchronize() by hipStreamSynchronize(stream) calls (#1359) --- include/ck_tile/host/timer.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/host/timer.hpp b/include/ck_tile/host/timer.hpp index e2baeaef7c..e5519643bf 100644 --- a/include/ck_tile/host/timer.hpp +++ b/include/ck_tile/host/timer.hpp @@ -27,7 +27,7 @@ struct gpu_timer CK_TILE_HOST void start(const hipStream_t& s) { - HIP_CHECK_ERROR(hipDeviceSynchronize()); + HIP_CHECK_ERROR(hipStreamSynchronize(s)); HIP_CHECK_ERROR(hipEventRecord(start_evt, s)); } @@ -51,15 +51,15 @@ struct gpu_timer struct cpu_timer { // torch.utils.benchmark.Timer(), there is a sync inside each timer callback - CK_TILE_HOST void start(const hipStream_t&) + CK_TILE_HOST void start(const hipStream_t& s) { - HIP_CHECK_ERROR(hipDeviceSynchronize()); + HIP_CHECK_ERROR(hipStreamSynchronize(s)); start_tick = std::chrono::high_resolution_clock::now(); } // torch.utils.benchmark.Timer(), there is a sync inside each timer callback - CK_TILE_HOST void stop(const hipStream_t&) + CK_TILE_HOST void stop(const hipStream_t& s) { - HIP_CHECK_ERROR(hipDeviceSynchronize()); + HIP_CHECK_ERROR(hipStreamSynchronize(s)); stop_tick = std::chrono::high_resolution_clock::now(); } // return in ms From 941d1f7ce0a7529e7e23039baff1da3064a5d10e Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 27 Jun 2024 00:33:34 -0700 Subject: [PATCH 05/16] Merging the gfx12 code into public repo. (#1362) --- CMakeLists.txt | 6 +- Jenkinsfile | 4 +- cmake/EnableCompilerWarnings.cmake | 2 +- example/01_gemm/gemm_wmma_fp16.cpp | 54 +- example/01_gemm/run_gemm_example.inc | 2 +- .../04_gemm_add_add_fastgelu/CMakeLists.txt | 2 +- .../batched_gemm_bias_e_permute_wmma_fp16.cpp | 4 +- .../cross_attention_forward_wmma_fp16.cpp | 6 +- .../self_attention_forward_wmma_fp16.cpp | 6 +- example/CMakeLists.txt | 4 +- include/ck/ck.hpp | 14 +- include/ck/host_utility/device_prop.hpp | 5 + .../gpu/block/blockwise_gemm_wmma.hpp | 499 ++++++++++++++++++ .../gpu/block/blockwise_gemm_xdlops.hpp | 7 + ...d_contraction_multiple_d_wmma_cshuffle.hpp | 23 +- .../device_batched_gemm_multiple_d_dl.hpp | 7 +- ...emm_softmax_gemm_permute_wmma_cshuffle.hpp | 9 +- ...ce_contraction_multiple_d_xdl_cshuffle.hpp | 4 +- .../device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp | 2 +- .../device/impl/device_fpAintB_gemm_wmma.hpp | 2 +- .../gpu/device/impl/device_gemm_dl.hpp | 2 +- .../device/impl/device_gemm_multiple_d_dl.hpp | 7 +- .../device_gemm_multiple_d_wmma_cshuffle.hpp | 2 +- .../gpu/device/impl/device_gemm_wmma.hpp | 21 +- ...conv_bwd_data_multiple_d_wmma_cshuffle.hpp | 2 +- .../device_grouped_conv_bwd_weight_dl.hpp | 5 +- ..._grouped_conv_bwd_weight_wmma_cshuffle.hpp | 2 +- ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp | 7 +- ...ice_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp | 4 +- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 2 +- .../device_grouped_gemm_multiple_d_dl.hpp | 7 +- ...e_grouped_query_attention_forward_wmma.hpp | 5 +- ...ice_multi_query_attention_forward_wmma.hpp | 5 +- ...atched_gemm_softmax_gemm_wmma_cshuffle.hpp | 20 +- .../gpu/grid/gridwise_fpAintB_gemm_wmma.hpp | 22 +- ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 44 +- .../gpu/grid/gridwise_gemm_wmma.hpp | 42 +- .../gpu/grid/gridwise_tensor_rearrange.hpp | 5 +- .../threadwise_tensor_slice_transfer.hpp | 109 +++- .../tensor_operation/gpu/warp/wmma_gemm.hpp | 147 +++++- include/ck/utility/amd_wmma.hpp | 82 +++ include/ck/utility/data_type.hpp | 2 +- include/ck/utility/synchronization.hpp | 17 + include/ck_tile/core/config.hpp | 5 +- .../gpu/CMakeLists.txt | 8 +- profiler/src/CMakeLists.txt | 4 +- test/CMakeLists.txt | 4 +- .../test_grouped_convnd_bwd_weight.cpp | 2 +- test/wmma_op/wmma_op_util.hpp | 16 + 49 files changed, 1099 insertions(+), 163 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e8626b2cb9..b27e6ab4fb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -117,7 +117,7 @@ else() add_definitions(-DPROFILER_ONLY) set(GPU_TARGETS "" CACHE STRING "" FORCE) if(GPU_TARGETS) - message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11") + message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12") endif() if(GPU_ARCH MATCHES "gfx90") rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a") @@ -127,8 +127,10 @@ else() rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") elseif(GPU_ARCH MATCHES "gfx11") rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") + elseif(GPU_ARCH MATCHES "gfx12") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201") else() - message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11") + message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12") endif() set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) endif() diff --git a/Jenkinsfile b/Jenkinsfile index 855fe8dff9..67e9b2fcb3 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -493,6 +493,7 @@ def Build_CK(Map conf=[:]){ def variant = env.STAGE_NAME def retimage + gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) @@ -660,9 +661,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM pipeline { agent none - triggers { - parameterizedCron(CRON_SETTINGS) - } options { parallelsAlwaysFailFast() } diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index fb2b38d688..93fd306e98 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -66,7 +66,7 @@ else() -Wunreachable-code -Wunused -Wno-reserved-identifier - -Werror + -Werror -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index 8c52e4f7d7..f8afe8d6db 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -23,45 +23,45 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle - < ALayout, - BLayout, - CLayout, - ADataType, + < ALayout, + BLayout, + CLayout, + ADataType, BDataType, - CDataType, - AccDataType, - CShuffleDataType, - AElementOp, - BElementOp, - CElementOp, - GemmDefault, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmDefault, 1, // Prefetch stage 128, // BlockSize 64, // MPerBlock 128, // NPerBlock 64, // KPerBlock - 8, // K1 + 2, // K1 16, // MPerWmma 16, // NPerWmma 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, 1, // C shuffle (M Repeat) Per store 1, // C shuffle (N Repeat) Per store - S<1, 32, 1, 4>, + S<1, 32, 1, 4>, 8>; // clang-format on diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index b04e4e53a8..cb15186c3b 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); break; case 4: - ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); break; case 5: diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index ab19f819e8..be47665a26 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32) set(target 1) endif() -endforeach() \ No newline at end of file +endforeach() diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index 2bbf430c4e..f556be887f 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN = 2, 4, 4, - true, + false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, - true, + false, 1, 1, S<1, 64, 1, 2>, diff --git a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp index 4c92c5497f..fac19f8b5a 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp @@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial #define CK_MHA_USE_WAVE_1 #define CK_MHA_USE_WAVE_2 #define CK_MHA_USE_WAVE_4 -#define CK_MHA_USE_WAVE_8 +//#define CK_MHA_USE_WAVE_8 using DeviceMHAFactory = std::tuple< #ifdef CK_MHA_USE_WAVE_1 @@ -277,10 +277,10 @@ using DeviceMHAFactory = S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, // CShuffleBlockTransfer MN 1, 1, S<1, 64, 1, 2>, 8, - MaskingSpec>, + MaskingSpec> #endif #ifdef CK_MHA_USE_WAVE_8 - ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, diff --git a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp index 8e037272b8..d463cc8716 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp @@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial #define CK_MHA_USE_WAVE_1 #define CK_MHA_USE_WAVE_2 #define CK_MHA_USE_WAVE_4 -#define CK_MHA_USE_WAVE_8 +//#define CK_MHA_USE_WAVE_8 using DeviceMHAFactory = std::tuple< #ifdef CK_MHA_USE_WAVE_1 @@ -277,10 +277,10 @@ using DeviceMHAFactory = S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, // CShuffleBlockTransfer MN 1, 1, S<1, 64, 1, 2>, 8, - MaskingSpec>, + MaskingSpec> #endif #ifdef CK_MHA_USE_WAVE_8 - ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index fd9f5cd89d..c9781637d6 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") message("removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() @@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") message("removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 32eea551f5..9528a30b4b 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) #define __gfx11__ #endif +#if defined(__gfx1200__) || defined(__gfx1201__) +#define __gfx12__ +#endif // buffer resource #ifndef __HIP_DEVICE_COMPILE__ // for host code @@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(__gfx103__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 -#elif defined(__gfx11__) +#elif defined(__gfx11__) || defined(__gfx12__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #endif @@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8 -#elif defined(__gfx11__) +#elif defined(__gfx11__) || defined(__gfx12__) #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8_GFX11 @@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #define CK_USE_AMD_MFMA_GFX940 #endif -// WMMA instruction -#ifndef __HIP_DEVICE_COMPILE__ // for host code -#define CK_USE_AMD_WMMA -#elif defined(__gfx11__) // for GPU code -#define CK_USE_AMD_WMMA -#endif - // buffer load #define CK_USE_AMD_BUFFER_LOAD 1 diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 116bb3ea02..83af2efe88 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -84,4 +84,9 @@ inline bool is_gfx11_supported() ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103"; } +inline bool is_gfx12_supported() +{ + return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201"; +} + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index 873539f8b1..3ea19da741 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -13,6 +13,504 @@ namespace ck { +#ifdef __gfx12__ +template +/* Option: Read from LDS, big buffer hold all threads required data + * Source + * A: K0PerBlock x MPerBlock x K1 + * B: K0PerBlock x NPerBlock x K1 + * Destination + * C, non-transpose + * thread level: MRepeat x NRepeat x MAccVgprs + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + * KPACK == WMMA_K = 16 + * + * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS) + * Source: + * A(if skip LDS): MRepeat x KPack + * B(if skip LDS): NRepeat x KPack + * Destination + * C, non-transpose + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + */ +struct BlockwiseGemmWMMA +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto WmmaK = Number<16>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. + static constexpr index_t WaveSize = 32; + + // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer + // When not use LDS, each Row read half of whole data from source buffer, exchange the data via + // permutation + static constexpr index_t A_KRow = 2; + static constexpr index_t B_KRow = 2; + + static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); + static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); + + static constexpr auto wmma_gemm = + WmmaGemm{}; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + // Default, Block buffer in LDS, thread level offset enabled + __device__ static auto CalculateAThreadOriginDataIndex() + { + if constexpr(AEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + + // |KRepeat |MRepeat|MWave |KRow |MLane |KPack + return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + if constexpr(BEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + + // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack + return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); + + constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex7D(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D(); + + return make_tuple( + Number{}, waveId_m, blk_idx[I0], Number{}, waveId_n, blk_idx[I1], blk_idx[I2]); + } + + using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && + NPerBlock % (NPerWMMA * NRepeat) == 0, + "wrong!"); + } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); + } + + // Thread level, register decriptor. Vector-write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; + return make_naive_tensor_descriptor( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), + make_tuple(Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + AccStride)); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple( + make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Provide dimension size + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Describe how data allocated in thread copy src buffer + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma + static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; + static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_assert(KPack % (A_K1 * A_KRow) == 0, ""); + static_assert(KPack % (B_K1 * B_KRow) == 0, ""); + + // basic intrinsic to determine loopover direction + if constexpr(MRepeat < NRepeat) + { + static_for<0, KPerBlock / KPack, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + static_for<0, KPack / B_KRow, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of + // k=0,kpack*1, .. + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + static_for<0, KPack / B_KRow, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + // C[M, N, NumRegWMMA] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); + + template + struct AThreadCopySelector; + + template <> + struct AThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + }; + + template <> + struct AThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< + FloatA, + FloatA, + decltype(a_block_desc_k0_m0_m1_m2_k1), + decltype(a_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + false>; + }; + + template + struct BThreadCopySelector; + + template <> + struct BThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + }; + + template <> + struct BThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< + FloatB, + FloatB, + decltype(b_block_desc_k0_n0_n1_n2_k1), + decltype(b_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + false>; + }; + + typename AThreadCopySelector::type a_thread_copy_; + typename BThreadCopySelector::type b_thread_copy_; +}; +#else template ::type a_thread_copy_; typename BThreadCopySelector::type b_thread_copy_; }; +#endif } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index e2296a55f7..d3f6344c27 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -487,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 // sync point. if constexpr(k.value != 0 || KPerInnerLoop == KPerThread) { +#ifdef __gfx12__ + asm volatile("\ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("s_barrier" ::); +#endif __builtin_amdgcn_sched_barrier(0); } static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index a157595593..ab3f3856aa 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto WmmaK = K1 == 16 ? 32 : 16; - static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; - static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; + static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; + + static constexpr auto AEnableLds_auto = + (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? false : true; // If true, LDS is used unconditionally static constexpr auto AEnableLds_manu = false; @@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(ck::is_gfx11_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { @@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle } else { - if(!(arg.a_kz_stride_ == 1 && - arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + if(!(arg.a_kz_stride_ == 1)) { - printf("DeviceOp: Vector Access A-k check failure\n"); - return false; + index_t LastK = + AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6); + if(LastK % ABlockTransferSrcScalarPerVector == 0) + { + printf("DeviceOp: Vector Access A-k check failure\n"); + return false; + } } } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp index 8fd14afc0c..1b487502f4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp @@ -70,8 +70,9 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ + defined(__gfx12__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index f1bc6a2261..f0f89f1d1b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -592,9 +592,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle return false; } - if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" && - ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" && - std::is_same::value) + if(!ck::is_lds_direct_load_supported() && std::is_same::value) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index b84e181306..1edae33be3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl { // check device if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() || - ck::is_gfx11_supported())) + ck::is_gfx11_supported() || ck::is_gfx12_supported())) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp index bf96324d00..553143e286 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp @@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index b1784b3858..eb0fb55f5d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index 93ab8a7e1d..a7cc546f53 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm{}; - static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); - static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); - static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; + static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; - static constexpr auto AEnableLds_auto = - (NWaves == 1 && is_same::value) ? false : true; + static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) && + is_same::value) + ? false + : true; static constexpr auto BEnableLds_auto = - (MWaves == 1 && is_same::value) ? false : true; + (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) && + is_same::value) + ? false + : true; // If true, LDS is used unconditionally static constexpr auto AEnableLds_manu = false; @@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index 6f74838fba..6bb5d431c9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { // check device - if(ck::is_gfx11_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 86091aeba9..cc26936fef 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -48,8 +48,9 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ - defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index 211185dfb0..5738be0fb3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { // check device - if(ck::is_gfx11_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index ce86ec54e5..c3fe54b075 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -90,8 +90,9 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ - defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -667,7 +668,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK // check device if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || - ck::is_gfx103_supported() || ck::is_gfx11_supported())) + ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index 5c9d63e2b0..c6b84b613c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -107,7 +107,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ - defined(__gfx11__)) + defined(__gfx11__) || defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -603,7 +603,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index ac392cddc4..060a16d1e2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -39,8 +39,9 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \ + defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -673,7 +674,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp index 4e14ed3a51..cc88c1a104 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp @@ -60,7 +60,7 @@ __global__ void bool input_permute, bool output_permute) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // clang-format off // *************************************************** @@ -165,6 +165,7 @@ __global__ void ignore = O; ignore = G0; ignore = G1; + ignore = alpha; ignore = input_permute; ignore = output_permute; #endif // end of if (defined(__gfx11__)) @@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma static bool IsSupportedArgument(const RawArg& arg) { - if(ck::is_gfx11_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp index 16717ff819..1754e07e6a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp @@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma if constexpr(B0EnableLds) { // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1 - constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( B0BlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma if constexpr(B1EnableLds) { // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1 - constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); - constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); + constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); + constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_LRow = I2; +#else constexpr auto B_LRow = I1; +#endif return transform_tensor_descriptor( B1BlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index 499eb7eb01..21dac6f9e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -50,7 +50,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; GridwiseGemm::template Run(p_a_grid, @@ -302,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma if constexpr(AEnableLds) { // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 - constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else constexpr auto A_KRow = I1; +#endif return transform_tensor_descriptor( ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -360,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma if constexpr(BEnableLds) { // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 - constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( BBlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 49a6dc3b0f..b3b057c80a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -54,7 +54,7 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -147,7 +147,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // printf("entry kernel launch"); __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; @@ -237,7 +237,7 @@ __global__ void const CDEElementwiseOperation cde_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; GridwiseOp::template Run(p_a_grid, @@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma } else { + constexpr auto A_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / A_KRow / K1; // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma } else { + constexpr auto B_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / B_KRow / K1; // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -495,12 +497,16 @@ struct GridwiseGemmMultipleD_Wmma if constexpr(AEnableLds) { // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 - constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else constexpr auto A_KRow = I1; +#endif return transform_tensor_descriptor( ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -534,12 +540,16 @@ struct GridwiseGemmMultipleD_Wmma if constexpr(BEnableLds) { // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 - constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( BBlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -571,15 +581,12 @@ struct GridwiseGemmMultipleD_Wmma // *Caution Here repeat is shuffle repeat GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = make_naive_tensor_descriptor_packed( make_tuple(I1, - Number{}, + Number{}, I1, - Number{})); + Number{})); return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; } @@ -799,8 +806,9 @@ struct GridwiseGemmMultipleD_Wmma const auto M = e_grid_desc_m_n.GetLength(I0); const auto N = e_grid_desc_m_n.GetLength(I1); - const auto MBlock = M / MPerBlock; - const auto NBlock = N / NPerBlock; + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( e_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 8e4117593c..4458b9356d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -45,7 +45,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; GridwiseGemm::template Run(p_a_grid, @@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma } else { + constexpr auto A_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / A_KRow / K1; // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma } else { + + constexpr auto B_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / B_KRow / K1; // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -290,12 +293,17 @@ struct GridwiseGemm_Wmma if constexpr(AEnableLds) { // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 - constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else constexpr auto A_KRow = I1; +#endif + return transform_tensor_descriptor( ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -348,12 +356,16 @@ struct GridwiseGemm_Wmma if constexpr(BEnableLds) { // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 - constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( BBlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -522,12 +534,6 @@ struct GridwiseGemm_Wmma c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; - using DefaultBlock2CTileMap = - remove_cvref_t; - struct SharedMemTrait { // LDS allocation for A and B: be careful of alignment @@ -559,6 +565,12 @@ struct GridwiseGemm_Wmma b_block_space_size_aligned * sizeof(BDataType)); }; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + template __device__ static void Run(const ADataType* __restrict__ p_a_grid, const BDataType* __restrict__ p_b_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp index 6772524e0a..1740749907 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp @@ -35,8 +35,9 @@ __global__ void const Block2ETileMap block_2_tile_map, const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ + defined(__gfx12__)) GridwiseTensorRearrangeKernel::Run(in_grid_desc, p_in_global, out_grid_desc, diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index bcce930fc7..d7a6a36244 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1304,7 +1304,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic ElementwiseOperation element_op_; }; -// Specilized for WMMA +// Specilized for WMMA-Navi3 // A single Wave32 is composed by double row // Data exchange allowed between these two rows // This RowLane Dst buf will be filled from two Src buf @@ -1439,4 +1439,111 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ElementwiseOperation element_op_{}; }; +// Specilized for WMMA-Navi4 +template ::type = false> +struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + ignore = src_idx; + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), + "wrong! Buffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); + + // scalar per access on each dim + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + // src_desc error, non constexpr, caused by merge transform + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + SrcData v_this_row; + // int type temp value due to intrinsic requirement + int temp = 0; + + // apply element-wise operation + element_op_(v_this_row, src_buf[Number{}]); + + // apply intra-row permute. + if constexpr(IntraRowSwizzlePerm) + { + temp = __builtin_amdgcn_permlane16( + temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); + v_this_row = type_convert_sp(temp); + } + + // apply type convert + dst_buf(Number{}) = type_convert_sp(v_this_row); + }); + }); + } + ElementwiseOperation element_op_{}; +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 565195f53e..9a9ebf5595 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -11,12 +11,17 @@ namespace ck { enum struct WmmaInstr { + // gfx11 wmma_f32_16x16x16_f16 = 0, wmma_f32_16x16x16_bf16, wmma_f16_16x16x16_f16, wmma_bf16_16x16x16_bf16, wmma_i32_16x16x16_iu8, - wmma_i32_16x16x16_iu4 + wmma_i32_16x16x16_iu4, + // gfx12 + wmma_f32_16x16x16_f16_gfx12, + wmma_f32_16x16x16_bf16_gfx12, + wmma_i32_16x16x16_iu8_gfx12, }; /* @@ -279,6 +284,122 @@ struct wmma_type +struct wmma_type> +{ + // Absolute fixing property + // * Data Pixel + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + // static constexpr index_t src_a_data_size = 2; + // static constexpr index_t src_b_data_size = 2; + // static constexpr index_t acc_data_size = 4; + // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // * Fixed in Navi3x, Will be wave mode dependent on Navi4x + // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4; + // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4; + // * num_acc_vgprs_per_wave alone M direction + // * num_subgroups alone M direction + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { + intrin_wmma_f32_16x16x16_f16_w32_gfx12::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + // static constexpr index_t src_a_data_size = 2; + // static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { + intrin_wmma_f32_16x16x16_bf16_w32_gfx12::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + // static constexpr index_t src_a_data_size = 2; + // static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { + intrin_wmma_i32_16x16x16_iu8_w32_gfx12::Run( + a, b, reg_c); + } + } +}; + template static constexpr auto GetWmma() { +#ifdef __gfx12__ + return WmmaInstr::wmma_f32_16x16x16_f16_gfx12; +#else return WmmaInstr::wmma_f32_16x16x16_f16; +#endif } template <> static constexpr auto GetWmma() { +#ifdef __gfx12__ + return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12; +#else return WmmaInstr::wmma_f32_16x16x16_bf16; +#endif } template <> @@ -320,8 +449,13 @@ struct WmmaSelector template <> static constexpr auto GetWmma() { +#ifdef __gfx12__ + return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12; +#else return WmmaInstr::wmma_i32_16x16x16_iu8; +#endif } + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> static constexpr auto GetWmma() @@ -502,6 +636,9 @@ struct WmmaGemm __device__ static auto GetSubGroupId() { + static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups == + wmma_instr.wave_size, + ""); return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups; } @@ -516,12 +653,20 @@ struct WmmaGemm __host__ __device__ static auto CalculateAThreadOriginDataIndex() { +#ifdef __gfx12__ + return GetLaneIdUnderSubGroup(); +#else return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow(); +#endif } __host__ __device__ static auto CalculateBThreadOriginDataIndex() { +#ifdef __gfx12__ + return GetLaneIdUnderSubGroup(); +#else return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup(); +#endif } __device__ static CIndex GetBeginOfThreadBlk() diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index 1bb0140f3e..322a0f94bb 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> } }; +// gfx12 +/********************************WAVE32 MODE***********************************************/ + +#if defined(__gfx1200__) || defined(__gfx1201__) +#define __gfx12__ +#endif + +// src: fp16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f16_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) + { + // * Inline assembly need to elimate the duplicated data load, compiler won't help you + // delete them. + // amd_assembly_wmma_f32_16x16x16_f16_w32( + // reg_a, reg_b, reg_c.template AsType()(Number<0>{})); +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: iu8, dst: i32 +template +struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12; + +template +struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp> +{ + template + __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + neg_a, + bit_cast(reg_a), + neg_b, + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + clamp); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + } // namespace ck #endif diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 93a1edefb6..4df14c6211 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -203,7 +203,7 @@ struct vector_type } }; -int static err = 0; +__device__ int static err = 0; template struct vector_type { diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp index 4fe5e39504..d6b6eac26c 100644 --- a/include/ck/utility/synchronization.hpp +++ b/include/ck/utility/synchronization.hpp @@ -10,12 +10,20 @@ namespace ck { __device__ void block_sync_lds() { #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM +#ifdef __gfx12__ + asm volatile("\ + s_wait_dscnt 0x0 \n \ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else // asm volatile("\ // s_waitcnt lgkmcnt(0) \n \ // s_barrier \ // " ::); __builtin_amdgcn_s_waitcnt(0xc07f); __builtin_amdgcn_s_barrier(); +#endif #else __syncthreads(); #endif @@ -23,11 +31,20 @@ __device__ void block_sync_lds() __device__ void block_sync_lds_direct_load() { +#ifdef __gfx12__ + asm volatile("\ + s_wait_vmcnt 0x0 \n \ + s_wait_dscnt 0x0 \n \ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("\ s_waitcnt vmcnt(0) \n \ s_waitcnt lgkmcnt(0) \n \ s_barrier \ " ::); +#endif } __device__ void s_nop() diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 344343d931..83637e18e4 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -17,6 +17,9 @@ #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) #define __gfx11__ #endif +#if defined(__gfx1200__) || defined(__gfx1201__) +#define __gfx12__ +#endif #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" @@ -155,7 +158,7 @@ #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(__gfx103__) // for GPU code #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000 -#elif defined(__gfx11__) // for GPU code +#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #endif diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 05b8c035c4..1bcc0f802b 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -59,7 +59,7 @@ function(add_instance_library INSTANCE_NAME) endforeach() # Do not build WMMA instances if gfx11 targets are not on the target list foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") message("removing wmma instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -177,7 +177,7 @@ FOREACH(subdir_path ${dir_list}) message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11")) + if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12")) message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.") set(add_inst 0) endif() @@ -185,11 +185,11 @@ FOREACH(subdir_path ${dir_list}) message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9")) + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9")) message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) + if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.") set(add_inst 0) endif() diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index fa0eb6f887..5262ca33a6 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -59,7 +59,7 @@ if(GPU_TARGETS MATCHES "gfx9") endif() -if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9") if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) endif() @@ -134,7 +134,7 @@ if(GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) endif() -if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") +if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) endif() diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 49b67992b1..66b4d3d27d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -60,7 +60,7 @@ function(add_test_executable TEST_NAME) endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") message("removing wmma test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -139,7 +139,7 @@ function(add_gtest_executable TEST_NAME) endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") message("removing wmma test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 5ef0730668..aee80cb2cb 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -44,7 +44,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test } } - if(ck::is_gfx11_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { // on gfx11 only support for 3d is implemented if constexpr(NDimSpatial{} != 3) diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp index 49782bce6e..d9ec94771a 100644 --- a/test/wmma_op/wmma_op_util.hpp +++ b/test/wmma_op/wmma_op_util.hpp @@ -140,10 +140,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele]; } +#ifdef __gfx12__ + asm volatile("\ + s_wait_dscnt 0x0 \n \ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("\ s_waitcnt lgkmcnt(0) \n \ s_barrier \ " ::); +#endif for(int ele = 0; ele < 16; ++ele) { @@ -155,10 +163,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8]; } +#ifdef __gfx12__ + asm volatile("\ + s_wait_dscnt 0x0 \n \ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("\ s_waitcnt lgkmcnt(0) \n \ s_barrier \ " ::); +#endif // sync threads, similar to mma_sync // __syncthreads(); From ed21948bcd8dd84a7d7cf625c9d8f472ea8f322f Mon Sep 17 00:00:00 2001 From: jakpiase Date: Thu, 27 Jun 2024 11:30:32 +0200 Subject: [PATCH 06/16] Add structural sparsity gemm instruction tests (#1309) * first version of smfmac test * add reviewer comments * add reviewer suggestions --- include/ck/utility/amd_smfmac.hpp | 69 ++++ library/include/ck/library/utility/fill.hpp | 37 +- test/CMakeLists.txt | 3 + test/smfmac_op/CMakeLists.txt | 2 + test/smfmac_op/smfmac_op.cpp | 82 +++++ test/smfmac_op/smfmac_op_util.hpp | 361 ++++++++++++++++++++ test/smfmac_op/smfmac_op_xdl.cpp | 89 +++++ 7 files changed, 642 insertions(+), 1 deletion(-) create mode 100644 include/ck/utility/amd_smfmac.hpp create mode 100644 test/smfmac_op/CMakeLists.txt create mode 100644 test/smfmac_op/smfmac_op.cpp create mode 100644 test/smfmac_op/smfmac_op_util.hpp create mode 100644 test/smfmac_op/smfmac_op_xdl.cpp diff --git a/include/ck/utility/amd_smfmac.hpp b/include/ck/utility/amd_smfmac.hpp new file mode 100644 index 0000000000..2342930859 --- /dev/null +++ b/include/ck/utility/amd_smfmac.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#pragma once + +namespace ck { + +template +struct intrin_smfmac_f32_16x16x32f16; + +template <> +struct intrin_smfmac_f32_16x16x32f16<16, 16> +{ + template + __device__ static void + Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); + } +}; + +template +struct intrin_smfmac_f32_16x16x32bf16; + +template <> +struct intrin_smfmac_f32_16x16x32bf16<16, 16> +{ + template + __device__ static void + Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); + } +}; + +template +struct intrin_smfmac_f32_32x32x16f16; + +template <> +struct intrin_smfmac_f32_32x32x16f16<32, 32> +{ + template + __device__ static void + Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); + } +}; + +template +struct intrin_smfmac_f32_32x32x16bf16; + +template <> +struct intrin_smfmac_f32_32x32x16bf16<32, 32> +{ + template + __device__ static void + Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); + } +}; + +} // namespace ck diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp index 4e075df43b..3336041354 100644 --- a/library/include/ck/library/utility/fill.hpp +++ b/library/include/ck/library/utility/fill.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -133,5 +133,40 @@ struct FillConstant } }; +template +struct TransformIntoStructuralSparsity +{ + // clang-format off + static constexpr T valid_sequences[] = { + 0, 0, 1, 1, + 0, 1, 0, 1, + 0, 1, 1, 0, + 1, 0, 0, 1, + 1, 0, 1, 0, + 1, 1, 0, 0, + }; + // clang-format on + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::for_each(first, last, [=, idx = 0](T& elem) mutable { + auto tmp_idx = idx; + idx += 1; + return elem *= valid_sequences[tmp_idx % (sizeof(valid_sequences) / sizeof(T))]; + }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + } // namespace utils } // namespace ck diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 66b4d3d27d..7ee37d211d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -209,4 +209,7 @@ add_subdirectory(wrapper) if(GPU_TARGETS MATCHES "gfx11") add_subdirectory(wmma_op) endif() +if(GPU_TARGETS MATCHES "gfx942") + add_subdirectory(smfmac_op) +endif() add_subdirectory(position_embedding) diff --git a/test/smfmac_op/CMakeLists.txt b/test/smfmac_op/CMakeLists.txt new file mode 100644 index 0000000000..4ffc423f54 --- /dev/null +++ b/test/smfmac_op/CMakeLists.txt @@ -0,0 +1,2 @@ +add_gtest_executable(test_smfmac_op smfmac_op_xdl.cpp) +target_link_libraries(test_smfmac_op PRIVATE utility) diff --git a/test/smfmac_op/smfmac_op.cpp b/test/smfmac_op/smfmac_op.cpp new file mode 100644 index 0000000000..de4f9414af --- /dev/null +++ b/test/smfmac_op/smfmac_op.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "test/smfmac_op/smfmac_op_util.hpp" + +template +bool run_test() +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + bool pass = true; + + const auto matmul_default = ck::smfmac_op_util::matmul; + + const auto smfmac_kernel_container = std::make_tuple(matmul_default); + + ck::static_for<0, 1, 1>{}([&](auto i) { + pass &= + ck::smfmac_op_util::TestSmfmac{}>( + smfmac_kernel_container)), + Src1Type, + Src2Type, + DstType, + GPUAccType, + CPUAccType, + decltype(Row{}), + decltype(Row{}), + decltype(Row{}), + PassThrough, + PassThrough, + PassThrough, + AccVecSize, + M, + N, + K>{}(std::get{}>(smfmac_kernel_container)); + }); + + return pass; +} +int main(int, char*[]) +{ + bool pass = true; + // clang-format off + // | Src1Type| Src1VecSize| Src2Type| Src2VecSize| DstType| DstVecSize| GPUAccType| CPUAccType| M| N| K| + pass &= run_test< ck::half_t, 4, ck::half_t, 8, float, 4, float, float,16,16,32>(); + pass &= run_test(); + pass &= run_test< ck::half_t, 4, ck::half_t, 8, float, 16, float, float,32,32,16>(); + pass &= run_test(); + // clang-format on + + std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass; +} diff --git a/test/smfmac_op/smfmac_op_util.hpp b/test/smfmac_op/smfmac_op_util.hpp new file mode 100644 index 0000000000..44122c551d --- /dev/null +++ b/test/smfmac_op/smfmac_op_util.hpp @@ -0,0 +1,361 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/utility/amd_smfmac.hpp" +#include "ck/library/utility/fill.hpp" + +namespace ck { +namespace smfmac_op_util { + +template +__device__ void +builtin_smfmac_naive_selector(const src_vec1&, const src_vec2&, const int32_t&, acc_vec&) +{ +} + +template <> +__device__ void +builtin_smfmac_naive_selector>( + const half4_t& reg_a, + const half8_t& reg_b, + const int32_t& reg_idx, + StaticBufferTupleOfVector& reg_c) +{ + intrin_smfmac_f32_16x16x32f16<16, 16>::Run( + reg_a, reg_b, reg_idx, reg_c.GetVectorTypeReference(Number<0>{})); +} + +template <> +__device__ void +builtin_smfmac_naive_selector>( + const bhalf4_t& reg_a, + const bhalf8_t& reg_b, + const int32_t& reg_idx, + StaticBufferTupleOfVector& reg_c) +{ + intrin_smfmac_f32_16x16x32bf16<16, 16>::Run( + reg_a, reg_b, reg_idx, reg_c.GetVectorTypeReference(Number<0>{})); +} + +template <> +__device__ void builtin_smfmac_naive_selector< + half4_t, + half8_t, + StaticBufferTupleOfVector>( + const half4_t& reg_a, + const half8_t& reg_b, + const int32_t& reg_idx, + StaticBufferTupleOfVector& reg_c) +{ + intrin_smfmac_f32_32x32x16f16<32, 32>::Run( + reg_a, reg_b, reg_idx, reg_c.GetVectorTypeReference(Number<0>{})); +} + +template <> +__device__ void builtin_smfmac_naive_selector< + bhalf4_t, + bhalf8_t, + StaticBufferTupleOfVector>( + const bhalf4_t& reg_a, + const bhalf8_t& reg_b, + const int32_t& reg_idx, + StaticBufferTupleOfVector& reg_c) +{ + intrin_smfmac_f32_32x32x16bf16<32, 32>::Run( + reg_a, reg_b, reg_idx, reg_c.GetVectorTypeReference(Number<0>{})); +} + +// Smfmac instructions are using 4:2 structural sparsity, that means that in every contignuous +// subgroup of 4 elements, atleast 2 must be equal to zero and the position of non-zero elements is +// stored in idx register to allow selection of corresponding B matrix elements for multiplication. +// Currently smfmac instructions support only A matrix as sparse +template +__global__ void matmul(const src1_t* a, const src2_t* b, dst_t* c) +{ + __shared__ src1_t a_shared[M * K]; + __shared__ src2_t b_shared[K * N]; + const int lane = threadIdx.x; + // smfmac's A part is storing only non-zero elements in 2VGPRs + // smfmac's B part is storing all elements in 4VGPRs + using src1_vec = typename vector_type::type; + using src1_full_vec = typename vector_type::type; + using src2_vec = typename vector_type::type; + src1_vec a_frag = {}; + src2_vec b_frag = {}; + + src1_full_vec a_temp = {}; + src2_vec b_temp = {}; + // initialize c fragment to 0 + using acc_vec = StaticBufferTupleOfVector; + acc_vec c_thread_buf_; + + for(int i = 0; i < 8; ++i) + { + a_temp[i] = a[(lane % M) * K + (lane / M) * 8 + i]; // M K + } + + for(int i = 0; i < 8; ++i) + { + b_temp[i] = b[(8 * (lane / N) + i) * N + (lane % N)]; // K N + } + + __syncthreads(); + + for(int i = 0; i < 8; ++i) + { + a_shared[(lane % M) * K + (lane / M) * 8 + i] = a_temp[i]; + } + for(int i = 0; i < 8; ++i) + { + b_shared[(8 * (lane / N) + i) * N + (lane % N)] = b_temp[i]; + } + + __syncthreads(); + + // Idx must be a 32-bit register and it is storing 4 2-bit indexes of A's non zero elements. + // It starts with last two elements of every 4 elements subgroup set as non-zero + int32_t idx = 0b11101110; + // Bit masks are for zeroing 0-3rd position of idx + static constexpr int32_t bit_clear_masks[4] = {0b11, 0b1100, 0b110000, 0b11000000}; + + src1_t curr_val; + int32_t a_pos = 0; + for(int j = 0; j < 2; ++j) + { + a_pos = j * 2; + for(int i = 0; i < 4; ++i) + { + curr_val = a_shared[(lane % M) * K + (lane / M) * 8 + 4 * j + i]; + if(curr_val != 0.0f) + { + idx &= ~bit_clear_masks[a_pos]; + idx |= (i % 4) << 2 * a_pos; + a_frag[a_pos] = curr_val; + a_pos++; + } + } + } + + for(int i = 0; i < 8; ++i) + { + b_frag[i] = b_shared[(8 * (lane / N) + i) * N + (lane % N)]; + } + + builtin_smfmac_naive_selector(a_frag, b_frag, idx, c_thread_buf_); + __syncthreads(); + + // store results from unpacked c_thread_buf_ output + if constexpr(K == 32) + { + static_for<0, acc_vec_size, 1>{}([&](auto i) { + c[(4 * (lane / 16) + i) * N + lane % 16] = + ck::type_convert(c_thread_buf_[Number{}]); + }); + } + else + { + static_for<0, acc_vec_size, 1>{}([&](auto i) { + c[((8 * (i / 4)) % 32 + 4 * (lane / 32) + i % 4) * N + lane % 32] = + ck::type_convert(c_thread_buf_[Number{}]); + }); + } +} + +struct GemmParams +{ + GemmParams() : M(16), N(16), K(32), StrideA(32), StrideB(16), StrideC(16), alpha(1), beta(0) {} + + ck::index_t M; + ck::index_t N; + ck::index_t K; + + ck::index_t StrideA; + ck::index_t StrideB; + ck::index_t StrideC; + + float alpha; + float beta; +}; + +template +void RunHostGEMM(const Tensor& A, + const Tensor& B, + Tensor& C, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) +{ + auto ref_gemm = GemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); +} + +template +bool RunDeviceGEMM(KernelType kernel, + const Tensor& A, + const Tensor& B, + Tensor& C) +{ + DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize()); + DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(A.mData.data()); + b_n_k_device_buf.ToDevice(B.mData.data()); + kernel<<<1, 64>>>(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer())); + c_m_n_device_buf.FromDevice(C.mData.data()); + + return true; +} + +template +struct TestSmfmac +{ + auto PrepareGemmTensor(const ck::smfmac_op_util::GemmParams& params) + { + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_n_k( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_host_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + auto f_generate_tensor_value = [](auto& tensor, auto type) { + using dataType = decltype(type); + tensor.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + }; + + f_generate_tensor_value(a_m_k, ADataType{}); + f_generate_tensor_value(b_n_k, BDataType{}); + ck::utils::TransformIntoStructuralSparsity{}(a_m_k); + + return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result); + } + + auto operator()(const DeviceSmfmac& smfmac_kernel) + { + std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name + << ", CLayout = " << CLayout{}.name << std::endl; + + // Arrange + ck::smfmac_op_util::GemmParams params; + params.M = M; + params.N = N; + params.K = K; + params.StrideA = K; // M K + params.StrideB = N; // K N + params.StrideC = N; // M N + + auto host_tensors = PrepareGemmTensor(params); + + const Tensor& a = std::get<0>(host_tensors); + const Tensor& b = std::get<1>(host_tensors); + Tensor& c_host = std::get<2>(host_tensors); + Tensor& c_device = std::get<3>(host_tensors); + + auto a_element_op = AElementwiseOperation{}; + auto b_element_op = BElementwiseOperation{}; + auto c_element_op = CElementwiseOperation{}; + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + ck::smfmac_op_util::RunHostGEMM( + a, b, c_host, a_element_op, b_element_op, c_element_op); + + // Act + bool is_supported = ck::smfmac_op_util::RunDeviceGEMM(smfmac_kernel, a, b, c_device); + + if(is_supported) + { + // Assert + bool res = false; + if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else + { + std::cout << "UNSUPPORTED CDataType" << std::endl; + } + + return res; + } + else + { + return true; + } + } +}; + +} // namespace smfmac_op_util +} // namespace ck diff --git a/test/smfmac_op/smfmac_op_xdl.cpp b/test/smfmac_op/smfmac_op_xdl.cpp new file mode 100644 index 0000000000..292fd259ea --- /dev/null +++ b/test/smfmac_op/smfmac_op_xdl.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "test/smfmac_op/smfmac_op_util.hpp" + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using Row = ck::tensor_layout::gemm::RowMajor; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +template +class TestSmfmac : public ::testing::Test +{ + protected: + using Src1Type = std::tuple_element_t<0, Tuple>; + static constexpr ck::index_t Src1VecSize = std::tuple_element_t<1, Tuple>{}.value; + using Src2Type = std::tuple_element_t<2, Tuple>; + static constexpr ck::index_t Src2VecSize = std::tuple_element_t<3, Tuple>{}.value; + using DstType = std::tuple_element_t<4, Tuple>; + static constexpr ck::index_t AccVecSize = std::tuple_element_t<5, Tuple>{}.value; + using GPUAccType = std::tuple_element_t<6, Tuple>; + using CPUAccType = std::tuple_element_t<7, Tuple>; + static constexpr ck::index_t M = std::tuple_element_t<8, Tuple>{}.value; + static constexpr ck::index_t N = std::tuple_element_t<9, Tuple>{}.value; + static constexpr ck::index_t K = std::tuple_element_t<10, Tuple>{}.value; + + void Run() + { + bool pass = true; + constexpr auto matmul_default = ck::smfmac_op_util::matmul; + + constexpr auto smfmac_kernel_container = std::make_tuple(matmul_default); + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + pass &= ck::smfmac_op_util::TestSmfmac< + std::tuple_element_t, + Src1Type, + Src2Type, + DstType, + GPUAccType, + CPUAccType, + decltype(Row{}), + decltype(Row{}), + decltype(Row{}), + PassThrough, + PassThrough, + PassThrough, + AccVecSize, + M, + N, + K>{}(std::get{}>(smfmac_kernel_container)); + }); + + EXPECT_TRUE(pass); + } +}; + +template +using I = ck::Number; + +using KernelTypes = + ::testing::Types, F16, I<8>, F32, I<4>, F32, F32, I<16>, I<16>, I<32>>, + std::tuple, BF16, I<8>, F32, I<4>, F32, F32, I<16>, I<16>, I<32>>, + std::tuple, F16, I<8>, F32, I<16>, F32, F32, I<32>, I<32>, I<16>>, + std::tuple, BF16, I<8>, F32, I<16>, F32, F32, I<32>, I<32>, I<16>>>; + +TYPED_TEST_SUITE(TestSmfmac, KernelTypes); +TYPED_TEST(TestSmfmac, TestSmfmacFP16BF16) { this->Run(); } From 3bb0fe6c7e321bb44b1ec247d34e93607f61162c Mon Sep 17 00:00:00 2001 From: alexxu-amd <159800977+alexxu-amd@users.noreply.github.com> Date: Thu, 27 Jun 2024 09:57:58 -0400 Subject: [PATCH 07/16] remove PR trigger for now due to high cost (#1329) --- .azuredevops/rocm-ci.yml | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/.azuredevops/rocm-ci.yml b/.azuredevops/rocm-ci.yml index 8c52856755..4161c2d5a4 100644 --- a/.azuredevops/rocm-ci.yml +++ b/.azuredevops/rocm-ci.yml @@ -23,20 +23,7 @@ trigger: - Jenkinsfile - LICENSE -pr: - autoCancel: true - branches: - include: - - develop - paths: - exclude: - - .github - - docs - - '.*.y*ml' - - '*.md' - - Jenkinsfile - - LICENSE - drafts: false +pr: none jobs: - template: ${{ variables.CI_COMPONENT_PATH }}/composable_kernel.yml@pipelines_repo From fafa567b3cd00993af70b78114429cb6fd4723a0 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 27 Jun 2024 11:09:00 -0700 Subject: [PATCH 08/16] Adding a private docker for ROCm6.2 release candidate. (#1365) * add private docker for rocm6.2_rc1 * update dockerfile --- Dockerfile | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index 0d3807f4d4..0c98188b99 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,11 +23,11 @@ RUN if [ "$ROCMVERSION" != "6.2" ]; then \ wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ - elif [ "$ROCMVERSION" = "6.2" ] && [ "$compiler_version" = "rc2" ]; then \ - sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.1-20.04-1_all.deb --no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.1-20.04-1_all.deb && \ - sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.1 rel-48 > /etc/apt/sources.list.d/rocm-build.list' && \ - amdgpu-repo --amdgpu-build=1736298; \ + elif [ "$ROCMVERSION" = "6.2" ] && [ "$compiler_version" = "rc1" ]; then \ + sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.2-20.04-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.2-20.04-1_all.deb && \ + sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.2 rel-8 > /etc/apt/sources.list.d/rocm-build.list' && \ + amdgpu-repo --amdgpu-build=1794148; \ fi RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" From 2525864fdae349ce8e839faec8971be915116b0f Mon Sep 17 00:00:00 2001 From: Ruturaj Vaidya Date: Thu, 27 Jun 2024 14:34:25 -0500 Subject: [PATCH 09/16] Update CMakeLists.txt (#1364) It is a good practice to check if the file CMakeLists.txt is in fact in the directory. --- example/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index c9781637d6..87c5a89f8a 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -181,7 +181,7 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME) # add all example subdir file(GLOB dir_list LIST_DIRECTORIES true *) FOREACH(subdir ${dir_list}) - IF(IS_DIRECTORY "${subdir}") + if(IS_DIRECTORY "${subdir}" AND EXISTS "${subdir}/CMakeLists.txt") add_subdirectory(${subdir}) ENDIF() ENDFOREACH() From 614ebd050ad36aa78dd84342e838b53fda7a9996 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 27 Jun 2024 22:14:36 -0700 Subject: [PATCH 10/16] Bump rocm-docs-core from 1.4.0 to 1.4.1 in /docs/sphinx (#1367) Bumps [rocm-docs-core](https://github.com/ROCm/rocm-docs-core) from 1.4.0 to 1.4.1. - [Release notes](https://github.com/ROCm/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/ROCm/rocm-docs-core/compare/v1.4.0...v1.4.1) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 9c9706c664..6605380a51 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.4.0 +rocm-docs-core==1.4.1 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 313c300267..a3566090ef 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -103,7 +103,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==1.4.0 +rocm-docs-core==1.4.1 # via -r requirements.in six==1.16.0 # via From 497ccb872b6a9a921c01df4dca49dac7cb242c72 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 28 Jun 2024 06:50:46 -0700 Subject: [PATCH 11/16] fix the optional ckProfiler grouped_gemm arguments (#1368) --- profiler/src/profile_grouped_gemm.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/profiler/src/profile_grouped_gemm.cpp b/profiler/src/profile_grouped_gemm.cpp index 25203d7b6c..fbf44d720f 100644 --- a/profiler/src/profile_grouped_gemm.cpp +++ b/profiler/src/profile_grouped_gemm.cpp @@ -98,8 +98,8 @@ int profile_grouped_gemm(int argc, char* argv[]) int n_iter = 10; if(argc == 17) { - n_warmup = std::stoi(argv[16]); - n_iter = std::stoi(argv[17]); + n_warmup = std::stoi(argv[15]); + n_iter = std::stoi(argv[16]); } #ifdef CK_ENABLE_FP16 From 959073842c0db839d45d565eb260fd018c996ce4 Mon Sep 17 00:00:00 2001 From: Jun Liu Date: Wed, 3 Jul 2024 23:34:38 -0700 Subject: [PATCH 12/16] Fix issue with multiple targets and remove smfmac tests from unsupported test targets (#1372) --- Jenkinsfile | 4 +-- .../25_wrapper/wrapper_basic_gemm.cpp | 17 +++++++++-- .../25_wrapper/wrapper_optimized_gemm.cpp | 16 ++++++++-- .../gemm_bilinear_wmma_fp16.cpp | 9 ++++++ .../gemm_bilinear_wmma_int8.cpp | 9 ++++++ ...ouped_conv_fwd_bias_relu_add_wmma_fp16.cpp | 13 ++++++++- ...ouped_conv_fwd_bias_relu_add_wmma_int8.cpp | 13 ++++++++- ...e_scale_softmax_gemm_permute_wmma_fp16.cpp | 13 ++++++++- ...m_scale_softmax_gemm_permute_wmma_fp16.cpp | 13 ++++++++- .../cross_attention_forward_wmma_fp16.cpp | 13 ++++++++- ...uped_query_attention_forward_wmma_fp16.cpp | 13 ++++++++- ...ulti_query_attention_forward_wmma_fp16.cpp | 13 ++++++++- .../self_attention_forward_wmma_fp16.cpp | 13 ++++++++- .../grouped_conv_bwd_data_wmma_fp16.cpp | 13 ++++++++- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 20 ++++++------- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 29 ++++++++++--------- include/ck/utility/amd_smfmac.hpp | 28 ++++++++++++++++++ test/CMakeLists.txt | 6 +++- test/grouped_convnd_bwd_data/CMakeLists.txt | 8 ++--- ...grouped_convnd_bwd_data_interface_wmma.cpp | 8 +++++ test/grouped_convnd_bwd_weight/CMakeLists.txt | 8 ++--- ...ouped_convnd_bwd_weight_interface_wmma.cpp | 8 +++++ test/grouped_convnd_fwd/CMakeLists.txt | 2 +- test/wmma_op/wmma_op_util.hpp | 4 ++- 24 files changed, 243 insertions(+), 50 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 67e9b2fcb3..8809fc50c9 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -886,10 +886,10 @@ pipeline { } agent{ label rocmnode("gfx90a") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1100;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx908;gfx90a" \ + -DGPU_TARGETS="gfx1100;gfx90a" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } diff --git a/client_example/25_wrapper/wrapper_basic_gemm.cpp b/client_example/25_wrapper/wrapper_basic_gemm.cpp index 59c5c243ce..23245dd188 100644 --- a/client_example/25_wrapper/wrapper_basic_gemm.cpp +++ b/client_example/25_wrapper/wrapper_basic_gemm.cpp @@ -7,19 +7,23 @@ #include #include +#include "ck/utility/common_header.hpp" +// __gfx9__ defined in the above header via ck.hpp +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/utility/host_tensor.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/check_err.hpp" -#include "ck/utility/common_header.hpp" #include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" #include "ck/wrapper/layout.hpp" #include "ck/wrapper/tensor.hpp" #include "ck/wrapper/operations/copy.hpp" #include "ck/wrapper/operations/gemm.hpp" #include "ck/wrapper/utils/kernel_utils.hpp" +#include "ck/host_utility/device_prop.hpp" struct SimpleDeviceMem { @@ -204,6 +208,14 @@ void PerformGemm(const ck::index_t M, int main(int argc, char* argv[]) { + bool is_supported = ck::is_xdl_supported(); + if(!is_supported) + { + std::cout << "WARNING: xdl example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + using DataType = ck::half_t; const auto thread_layout = ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}), @@ -213,3 +225,4 @@ int main(int argc, char* argv[]) 3840, 4096, 4096, tile_shape, thread_layout); return 0; } +#endif diff --git a/client_example/25_wrapper/wrapper_optimized_gemm.cpp b/client_example/25_wrapper/wrapper_optimized_gemm.cpp index b6294c2393..31e20342df 100644 --- a/client_example/25_wrapper/wrapper_optimized_gemm.cpp +++ b/client_example/25_wrapper/wrapper_optimized_gemm.cpp @@ -7,18 +7,21 @@ #include #include -#include "ck/library/utility/host_tensor.hpp" +#include "ck/utility/common_header.hpp" +// __gfx9__ defined in the above header via ck.hpp +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) #include "ck/host_utility/kernel_launch.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/check_err.hpp" -#include "ck/utility/common_header.hpp" #include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" #include "ck/wrapper/layout.hpp" #include "ck/wrapper/tensor.hpp" #include "ck/wrapper/operations/copy.hpp" #include "ck/wrapper/operations/gemm.hpp" #include "ck/wrapper/utils/kernel_utils.hpp" +#include "ck/host_utility/device_prop.hpp" struct SimpleDeviceMem { @@ -296,6 +299,14 @@ void PerformGemm(const ck::index_t M, int main(int argc, char* argv[]) { + bool is_supported = ck::is_xdl_supported(); + if(!is_supported) + { + std::cout << "WARNING: xdl example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + using DataType = ck::half_t; const auto thread_layout = ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), @@ -305,3 +316,4 @@ int main(int argc, char* argv[]) 3840, 4096, 4096, tile_shape, thread_layout); return 0; } +#endif diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index d1b820da7b..18731e810e 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -17,6 +17,7 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/utility/check_err.hpp" +#include "ck/host_utility/device_prop.hpp" struct AlphaBetaAdd { @@ -175,6 +176,14 @@ int main(int argc, char* argv[]) exit(0); } + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { using namespace ck::literals; diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp index aca136f801..87812369bd 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -17,6 +17,7 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/utility/check_err.hpp" +#include "ck/host_utility/device_prop.hpp" struct AlphaBetaAdd { @@ -175,6 +176,14 @@ int main(int argc, char* argv[]) exit(0); } + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { using namespace ck::literals; diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp index 039d250299..ff873d26bc 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" // kernel data types using InKernelDataType = FP16; @@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; #include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); +} diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp index 793324970e..662a6f611b 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #include "common_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" // kernel data types using InKernelDataType = I8; @@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; #include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp index 2c7bacfc4e..69ab5c5c0b 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp @@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -163,4 +164,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< #include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp index d9ab645ee9..f5cedb14c9 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp @@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -285,4 +286,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< #include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp index fac19f8b5a..41c6dff2df 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp @@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -351,4 +352,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< #include "run_cross_attention_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp index 12dcfcc36d..955c25f0d1 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp @@ -28,6 +28,7 @@ Example is GQA-4 #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -299,4 +300,14 @@ using ReferenceGemm1Instance = #include "run_grouped_query_attention_forward_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp index 694a320a45..112be07c49 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp @@ -26,6 +26,7 @@ Shazeer, Noam. “Fast Transformer Decoding: One Write-Head Is All You Need.” #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -284,4 +285,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm_ #include "run_multi_query_attention_forward_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp index d463cc8716..9ec1bc933f 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp @@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -329,4 +330,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< #include "run_self_attention_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp index 5baa521501..3e3ae7edbd 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp @@ -3,6 +3,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" #include "common.hpp" +#include "ck/host_utility/device_prop.hpp" using OutDataType = FP16; using WeiDataType = FP16; @@ -31,4 +32,14 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat #include "run_grouped_conv_bwd_data_example.inc" -int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run_grouped_conv_bwd_data_example(argc, argv); +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index f4f496fc10..d9e300b737 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -47,12 +47,12 @@ __global__ void #endif kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const index_t num_k_per_block) + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + [[maybe_unused]] const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -103,12 +103,12 @@ __global__ void #endif kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const index_t num_k_per_block) + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + [[maybe_unused]] const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 415ae3d496..a4d4a01a01 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -69,14 +69,15 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffset compute_ptr_offset_of_groups, - const ComputePtrOffset compute_ptr_offset_of_n, - const index_t groups_count) + kernel_grouped_conv_fwd_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n, + [[maybe_unused]] const index_t groups_count) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group @@ -132,13 +133,13 @@ __global__ void #endif kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds( typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffset compute_ptr_offset_of_groups, - const ComputePtrOffset compute_ptr_offset_of_n, - const index_t groups_count) + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n, + [[maybe_unused]] const index_t groups_count) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group diff --git a/include/ck/utility/amd_smfmac.hpp b/include/ck/utility/amd_smfmac.hpp index 2342930859..abb8d9f5ef 100644 --- a/include/ck/utility/amd_smfmac.hpp +++ b/include/ck/utility/amd_smfmac.hpp @@ -16,8 +16,15 @@ struct intrin_smfmac_f32_16x16x32f16<16, 16> __device__ static void Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) { +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif } }; @@ -31,8 +38,15 @@ struct intrin_smfmac_f32_16x16x32bf16<16, 16> __device__ static void Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) { +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif } }; @@ -46,8 +60,15 @@ struct intrin_smfmac_f32_32x32x16f16<32, 32> __device__ static void Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) { +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif } }; @@ -61,8 +82,15 @@ struct intrin_smfmac_f32_32x32x16bf16<32, 32> __device__ static void Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) { +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif } }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 7ee37d211d..3b121fc309 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -71,6 +71,8 @@ function(add_test_executable TEST_NAME) list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) elseif(ARGN MATCHES "_wmma") list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + elseif(ARGN MATCHES "_smfmac") + list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a) endif() set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) @@ -150,6 +152,8 @@ function(add_gtest_executable TEST_NAME) list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) elseif(ARGN MATCHES "_wmma") list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + elseif(ARGN MATCHES "_smfmac") + list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a) endif() set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) @@ -209,7 +213,7 @@ add_subdirectory(wrapper) if(GPU_TARGETS MATCHES "gfx11") add_subdirectory(wmma_op) endif() -if(GPU_TARGETS MATCHES "gfx942") +if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2 add_subdirectory(smfmac_op) endif() add_subdirectory(position_embedding) diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index 3507989bae..8edb715200 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -2,11 +2,11 @@ add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_x if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) endif() -add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp) +add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) + target_link_libraries(test_grouped_convnd_bwd_data_interface_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance) endif() -add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp) +add_gtest_executable(test_grouped_convnd_bwd_data_interface_wmma test_grouped_convnd_bwd_data_interface_wmma.cpp) if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) + target_link_libraries(test_grouped_convnd_bwd_data_interface_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance) endif() diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp index c0429c6d09..fbb6ffc6f5 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp @@ -52,6 +52,14 @@ class TestGroupedConvndBwdData : public ::testing::Test ck::utils::conv::ConvParam conv_param; + void SetUp() override + { + if(!ck::is_gfx11_supported()) + { + GTEST_SKIP(); + } + } + template bool Run() { diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index 54b514e7a1..313b5ba4ca 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -5,13 +5,13 @@ if(GPU_TARGETS MATCHES "gfx9" OR DL_KERNELS) add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv3d_bwd_weight_instance) endif() -add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp) +add_gtest_executable(test_grouped_convnd_bwd_weight_interface_xdl test_grouped_convnd_bwd_weight_interface_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) + target_link_libraries(test_grouped_convnd_bwd_weight_interface_xdl PRIVATE utility) endif() -add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp) +add_gtest_executable(test_grouped_convnd_bwd_weight_interface_wmma test_grouped_convnd_bwd_weight_interface_wmma.cpp) if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) + target_link_libraries(test_grouped_convnd_bwd_weight_interface_wmma PRIVATE utility) endif() add_gtest_executable(test_grouped_conv_bwd_weight_xdl_bilinear test_grouped_conv_bwd_weight_xdl_bilinear.cpp) if(result EQUAL 0) diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp index 1dcb8f866d..2e2f5332ae 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp @@ -52,6 +52,14 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ck::utils::conv::ConvParam conv_param; + void SetUp() override + { + if(!ck::is_gfx11_supported()) + { + GTEST_SKIP(); + } + } + template bool Run() { diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index 1eba91382e..f611e66243 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -1,6 +1,6 @@ if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp) - if(GPU_TARGETS MATCHES "gfx11") + if((GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9")) target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) else() target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp index d9ec94771a..3e511ab5bf 100644 --- a/test/wmma_op/wmma_op_util.hpp +++ b/test/wmma_op/wmma_op_util.hpp @@ -11,6 +11,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/utility/amd_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace wmma_op_util { @@ -373,7 +374,8 @@ struct TestWmma a, b, c_host, a_element_op, b_element_op, c_element_op); // Act - bool is_supported = ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device); + bool is_supported = ck::is_gfx11_supported() && + ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device); if(is_supported) { From eaa870a1ab91fedfc614609fbb6e843ae5231dca Mon Sep 17 00:00:00 2001 From: jakpiase Date: Thu, 4 Jul 2024 12:00:14 +0200 Subject: [PATCH 13/16] Add structural sparsity xdlops (#1363) * Implemented smfmac xdlops * add reviewer comments --- .../gpu/warp/smfmac_xdlops_gemm.hpp | 409 ++++++++++++++++++ 1 file changed, 409 insertions(+) create mode 100644 include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp diff --git a/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp new file mode 100644 index 0000000000..33c07f34f7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp @@ -0,0 +1,409 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/amd_smfmac.hpp" + +namespace ck { + +enum struct SmfmacInstr +{ + smfmac_f32_16x16x32f16 = 0, + smfmac_f32_32x32x16f16, + smfmac_f32_16x16x32bf16, + smfmac_f32_32x32x16bf16, +}; + +template +struct smfmac_type; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_16x16x32f16::Run(a, b, idx, reg_c); + } +}; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_32x32x16f16::Run(a, b, idx, reg_c); + } +}; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_16x16x32bf16::Run(a, b, idx, reg_c); + } +}; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_32x32x16bf16::Run(a, b, idx, reg_c); + } +}; + +template +struct SmfmacSelector +{ + template + static constexpr auto GetSmfmac(); + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_16x16x32f16; + } + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_32x32x16f16; + } + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_16x16x32bf16; + } + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_32x32x16bf16; + } + + static constexpr auto selected_smfmac = + smfmac_type()>{}; + + __host__ __device__ constexpr SmfmacSelector() + { + static_assert(selected_smfmac.group_size * selected_smfmac.num_groups_per_blk == + selected_smfmac.num_regs_per_blk, + "wrong! num_regs_per_blk"); + + static_assert(selected_smfmac.num_threads_per_blk == selected_smfmac.n_per_blk, + "n_per_blk != num_threads_per_blk"); + + static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.num_input_blks == + selected_smfmac.m_per_blk, + "m_per_blk != num_input_blks * num_regs_per_blk"); + + static_assert(selected_smfmac.num_output_blks == selected_smfmac.num_input_blks || + selected_smfmac.num_output_blks == 1, + "incorrect num_output_blks"); + + static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.wave_size == + selected_smfmac.m_per_blk * selected_smfmac.n_per_blk, + "num_regs_per_blk incorrect"); + + static_assert(selected_smfmac.is_k_reduction || + (selected_smfmac.num_input_blks == selected_smfmac.num_output_blks), + "is_k_reduction wrong!"); + } + + static constexpr index_t GetKPerXdlops() + { + return (selected_smfmac.is_k_reduction ? selected_smfmac.num_input_blks : 1) * + selected_smfmac.k_per_blk; + } + + static constexpr index_t GetK1PerXdlops() { return selected_smfmac.k_per_blk; } +}; + +template +struct SparseXdlopsGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __device__ static constexpr index_t GetNumBlks() { return smfmac_instr.num_output_blks; } + + __device__ static constexpr index_t GetNumXdlops() + { + return MPerXdlops * NPerXdlops / + (smfmac_instr.m_per_blk * smfmac_instr.n_per_blk * smfmac_instr.num_output_blks); + } + + __host__ __device__ constexpr SparseXdlopsGemm() + { + static_assert(NPerXdlops == 16 || NPerXdlops == 32, + "Only support GemmNPerXdlops == 16 or 32 for smfmac xdlops"); + + static_assert(MPerXdlops == 16 || MPerXdlops == 32, + "Only support GemmMPerXdlops == 16 or 32 for smfmac xdlops"); + + static_assert(KPack % smfmac_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); + } + + // XDL output supporting C = A * B + // M2_N2 -> M2_M3_M4_N2 + template + __host__ __device__ static constexpr auto + MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) + { + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + + return transform_tensor_descriptor( + c_desc_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 5, 6>{}, + Sequence<7>{})); + } + + template + __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2) + { + const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_g_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(G), + make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(smfmac_instr.num_groups_per_blk, + smfmac_instr.num_input_blks, + smfmac_instr.group_size)), + make_pass_through_transform(smfmac_instr.num_threads_per_blk)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6, 7>{}, + Sequence<8>{})); + } + + __device__ static constexpr index_t GetRegSizePerXdlops() + { + return MPerXdlops * NPerXdlops / smfmac_instr.wave_size; + } + + __device__ static constexpr index_t GetWaveSize() { return smfmac_instr.wave_size; } + + template + __device__ void + Run(const FloatA& p_a_wave, const FloatB& p_b_wave, const Idx& idx, FloatC& p_c_thread) const + { + static_assert(is_same::value || is_same::value, + "base base_type must be half or bfloat16!"); + + static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) { + smfmac_instr.template run( + p_a_wave[k], p_b_wave[k], idx[k], p_c_thread); + }); + } + + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % smfmac_instr.wave_size; } + + __device__ static auto GetBlkIdx() + { + const auto laneId = GetLaneId(); + + constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform( + make_tuple(1, smfmac_instr.num_input_blks, smfmac_instr.num_threads_per_blk))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto blk_idx = + threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId)); + + const auto blk_id = blk_idx[I1]; + const auto blk_td = blk_idx[I2]; + + return make_tuple(blk_id, blk_td); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(smfmac_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(smfmac_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + index_t n_offset = blk_i * smfmac_instr.n_per_blk + blk_td; + index_t m_offset = xdlops_i * smfmac_instr.m_per_blk + blk_id * smfmac_instr.group_size; + + return CIndex{m_offset, n_offset}; + } + + __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + return CIndex4D{I0, blk_id, I0, blk_td}; + } + + static constexpr auto smfmac = + SmfmacSelector{}; + + static constexpr auto smfmac_instr = smfmac.selected_smfmac; + + static constexpr auto KPerXdlops = smfmac.GetKPerXdlops(); + static constexpr auto K1PerXdlops = smfmac.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + + __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() + { + return make_tuple( + Number{}, I1, Number{}, I1); + } +}; + +} // namespace ck From 75e622f02f964523dc2c60125904ce4018fae0f2 Mon Sep 17 00:00:00 2001 From: Harisankar Sadasivan <135730918+hsadasiv@users.noreply.github.com> Date: Fri, 5 Jul 2024 21:40:30 -0700 Subject: [PATCH 14/16] Universal streamk with atomics (#1360) * universal streamk with atomics with ckprofiler support. grid_size and streamk strategy are tunable. grid_size of -1 leads to #WGs = maximum occupancy X num_CUs. implementation supports many different streamk policies: 1-tile, 2-tile, 3-tile and 4-tile. streamk strategy of -1 leads to default streamk policy (4-tile). * Update README.md * fixing clang-format issues * removed conflicts in struct members between streamk and universal streamk * corrected arg parsing for streamk and universal streamk * added stream-k policies for 3 tile and 4 tile * fixed argument type issue with parsing cmd args * changes suggested in PR review are made- removing comments and correcting copyright * file permissions updated * added default value support for grid_size and streamk-policy selection set to -1 * print messages for arguments * print messages for arguments * print messages for arguments1 --- .pre-commit-config.yaml | 0 example/01_gemm/CMakeLists.txt | 2 + example/01_gemm/README.md | 18 + example/01_gemm/common.hpp | 69 +- example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp | 48 + .../01_gemm/run_gemm_example_streamk_v2.inc | 298 +++ .../gpu/device/device_gemm_streamk_v2.hpp | 44 + .../device_gemm_xdl_cshuffle_streamk_v3.hpp | 556 +++++ .../gpu/grid/block_to_ctile_map.hpp | 322 +++ .../gridwise_gemm_xdl_cshuffle_streamk_v3.hpp | 2010 +++++++++++++++++ .../gpu/gemm_universal_streamk.hpp | 337 +++ .../gpu/gemm_universal_streamk/CMakeLists.txt | 26 + ...universal_streamk_f16_f16_f16_mk_kn_mn.hpp | 91 + ...f16_f16_mk_kn_mn_comp_default_instance.cpp | 30 + ...16_f16_mk_kn_mn_comp_kpadding_instance.cpp | 30 + ..._f16_mk_kn_mn_comp_mnkpadding_instance.cpp | 30 + ...6_f16_mk_kn_mn_comp_mnpadding_instance.cpp | 30 + ...6_f16_mk_kn_mn_mem_v1_default_instance.cpp | 31 + ..._f16_mk_kn_mn_mem_v1_kpadding_instance.cpp | 31 + ...16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp | 31 + ...6_f16_mk_kn_mn_mem_v2_default_instance.cpp | 31 + ..._f16_mk_kn_mn_mem_v2_kpadding_instance.cpp | 31 + ...16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp | 31 + ...universal_streamk_f16_f16_f16_mk_nk_mn.hpp | 98 + ...f16_f16_mk_nk_mn_comp_default_instance.cpp | 30 + ...16_f16_mk_nk_mn_comp_kpadding_instance.cpp | 30 + ..._f16_mk_nk_mn_comp_mnkpadding_instance.cpp | 30 + ...6_f16_mk_nk_mn_comp_mnpadding_instance.cpp | 30 + ...6_f16_mk_nk_mn_mem_v1_default_instance.cpp | 31 + ..._f16_mk_nk_mn_mem_v1_kpadding_instance.cpp | 31 + ...16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp | 31 + ...6_f16_mk_nk_mn_mem_v2_default_instance.cpp | 31 + ..._f16_mk_nk_mn_mem_v2_kpadding_instance.cpp | 31 + ...16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp | 31 + .../gemm_universal_streamk/CMakeLists.txt | 26 + ...universal_streamk_f16_f16_f16_mk_kn_mn.hpp | 91 + ...f16_f16_mk_kn_mn_comp_default_instance.cpp | 30 + ...16_f16_mk_kn_mn_comp_kpadding_instance.cpp | 30 + ..._f16_mk_kn_mn_comp_mnkpadding_instance.cpp | 30 + ...6_f16_mk_kn_mn_comp_mnpadding_instance.cpp | 30 + ...6_f16_mk_kn_mn_mem_v1_default_instance.cpp | 31 + ..._f16_mk_kn_mn_mem_v1_kpadding_instance.cpp | 31 + ...16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp | 31 + ...6_f16_mk_kn_mn_mem_v2_default_instance.cpp | 31 + ..._f16_mk_kn_mn_mem_v2_kpadding_instance.cpp | 31 + ...16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp | 31 + ...universal_streamk_f16_f16_f16_mk_nk_mn.hpp | 98 + ...f16_f16_mk_nk_mn_comp_default_instance.cpp | 30 + ...16_f16_mk_nk_mn_comp_kpadding_instance.cpp | 30 + ..._f16_mk_nk_mn_comp_mnkpadding_instance.cpp | 30 + ...6_f16_mk_nk_mn_comp_mnpadding_instance.cpp | 30 + ...6_f16_mk_nk_mn_mem_v1_default_instance.cpp | 31 + ..._f16_mk_nk_mn_mem_v1_kpadding_instance.cpp | 31 + ...16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp | 31 + ...6_f16_mk_nk_mn_mem_v2_default_instance.cpp | 31 + ..._f16_mk_nk_mn_mem_v2_kpadding_instance.cpp | 31 + ...16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp | 31 + .../profile_gemm_universal_streamk_impl.hpp | 332 +++ profiler/src/CMakeLists.txt | 2 + .../src/profile_gemm_universal_streamk.cpp | 156 ++ script/check_copyright_year.sh | 0 61 files changed, 5846 insertions(+), 2 deletions(-) mode change 100644 => 100755 .pre-commit-config.yaml create mode 100644 example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp create mode 100644 example/01_gemm/run_gemm_example_streamk_v2.inc create mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp create mode 100644 profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp mode change 100644 => 100755 profiler/src/CMakeLists.txt create mode 100644 profiler/src/profile_gemm_universal_streamk.cpp mode change 100755 => 100644 script/check_copyright_year.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml old mode 100644 new mode 100755 diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 23683de449..98fd9c6b77 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -22,6 +22,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16) add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2) +add_example_executable(example_gemm_xdl_fp16_streamk_v3 gemm_xdl_fp16_streamk_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_streamk_v3) add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3) add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp) diff --git a/example/01_gemm/README.md b/example/01_gemm/README.md index a09e69255f..5edec1f043 100644 --- a/example/01_gemm/README.md +++ b/example/01_gemm/README.md @@ -7,3 +7,21 @@ #arg3: run kernel # of times (>1) ./bin/example_gemm_xdl 0 1 5 ``` + +# Instructions for ```example_gemm_xdl_fp16_streamk_v3``` + +## Run ```example_gemm_xdl_fp16_streamk_v3``` +```bash +arg1: verification (0=no, 1=yes) +arg2: initialization (0=no init, 1=integer value, 2=decimal value) +arg3: time kernel (0=no, 1=yes) +arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC +arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK) +arg11: Grid_size(-1 for max occupancy) +bin/example_gemm_xdl_fp16_streamk_v3 1 2 1 3840 4096 4096 4096 4096 4096 1 -1 +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +problem {M:3840, N:4096, K:4096, SA:4096, SB:4096, SC:4096, MP:4032, NP:4096, KRead:4096, KP:4096, AK0:512, BK0:2048, MBlock: 18, NBlock: 16, Stream-K Selection:1, Grid size:-1} +Perf: 0.292022 ms, 441.23 TFlops, 330.348 GB/s, DeviceGemmXdlUniversal BlkSize: 256, BlkTile: 224x256x64, WaveTile: 16x16, WaveMap: 7x8, VmemReadVec: 8x8, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, BlkGemmPipelinePrefetchStages: 2 +``` diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index ef87d9c2fc..3d8f4565cb 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -45,6 +45,19 @@ struct ProblemSizeStreamK final ck::index_t NumSKBlocks = -1; }; +struct ProblemSizeStreamK_universal final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + ck::index_t Grid_size = -1; // defaults to max occupancy + ck::index_t Streamk_sel = 1; // defaults to 1-tile SK +}; struct ProblemSizeSplitK final { @@ -123,6 +136,57 @@ bool parse_cmd_args(int argc, return true; } +template <> +bool parse_cmd_args(int argc, + char* argv[], + ProblemSizeStreamK_universal& problem_size, + ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc >= 10) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideC = std::stoi(argv[9]); + + if(argc >= 11) + { + problem_size.Streamk_sel = std::stoi(argv[10]); + problem_size.Grid_size = std::stoi(argv[11]); + } + } + else + { + std::cerr + << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl + << "arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)" + << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; + return false; + } + + return true; +} + template <> bool parse_cmd_args(int argc, char* argv[], @@ -165,7 +229,8 @@ bool parse_cmd_args(int argc, << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl - << "arg10: NumSKBlocks(optional)" << std::endl; + << "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)" + << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; return false; } diff --git a/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp b/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp new file mode 100644 index 0000000000..5b163962b9 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2_Streamk_Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 224, 256, + 64, 8, 2, + 16, 16, + 7, 8, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 2, 0, + 1, 2, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example_streamk_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); } diff --git a/example/01_gemm/run_gemm_example_streamk_v2.inc b/example/01_gemm/run_gemm_example_streamk_v2.inc new file mode 100644 index 0000000000..6679f95157 --- /dev/null +++ b/example/01_gemm/run_gemm_example_streamk_v2.inc @@ -0,0 +1,298 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto Grid_size = problem_size.Grid_size; + auto Streamk_sel = problem_size.Streamk_sel; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + auto f_get_default_streamk_policy = [](ck::index_t streamk_sel) { + if(streamk_sel == -1) + { + return static_cast(4); + } + else + return static_cast(streamk_sel); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Streamk_sel = f_get_default_streamk_policy(Streamk_sel); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + +#ifdef BUILD_INT4_EXAMPLE + DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) * + c_m_n_device_result.mDesc.GetElementSpaceSize()); + + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + + a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); +#else + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); +#endif + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2_Streamk_Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + Streamk_sel, + Grid_size, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1}); +#ifdef BUILD_INT4_EXAMPLE + Tensor c_m_n_device_result_converted(c_m_n_host_result.mDesc); + + c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data()); + + c_m_n_device_result = c_m_n_device_result_converted.CopyAsType(); + + return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); +#else + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); +#endif + } + + if(config.time_kernel) + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_universal_streamk_example(int argc, char* argv[]) +{ + ProblemSizeStreamK_universal problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp new file mode 100644 index 0000000000..1a4d684f14 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Streamk_V2 : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t Streamk_sel, + ck::index_t Grid_size, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp new file mode 100644 index 0000000000..452063156e --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp @@ -0,0 +1,556 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_streamk_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + float ave_time = 0; + + index_t k_grain = KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + hipGetErrorString(hipMemsetAsync( + arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_)); + const auto Run = [&](const auto& kernel) { + dim3 grid_dim; + if(arg.Grid_size < 0) + { + int occupancy, num_cu; + hipError_t rtn; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, kernel, BlockSize, 0); + hip_check_error(rtn); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + rtn = hipGetDevice(&dev); + hip_check_error(rtn); + rtn = hipGetDeviceProperties(&dev_prop, dev); + hip_check_error(rtn); + num_cu = dev_prop.multiProcessorCount; + + arg.Grid_size = num_cu * occupancy; + grid_dim = arg.Grid_size; + } + else + grid_dim = arg.Grid_size; + + if(stream_config.flush_cache) + { + Argument arg_ = arg; + ck::utility::RotatingMemWrapper rotating_mem( + arg_, + stream_config.rotating_count, + arg_.M * arg_.K * sizeof(ADataType), + arg_.K * arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, run_flush_cache, kernel, grid_dim, dim3(BlockSize), 0, arg_); + } + else + { + + ave_time = launch_and_time_kernel( + stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { + + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t streamk_sel, + index_t Grid_size, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + + return Argument{ + p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size}; // HS + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t streamk_sel, + index_t Grid_size, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + streamk_sel, + Grid_size); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +struct BlockToCTileMap_GemmStreamK_v2 +{ + static constexpr uint32_t min_k_iters_per_sk_block = 2; + static constexpr uint32_t MPerBlock = MPerBlock_; + static constexpr uint32_t NPerBlock = NPerBlock_; + static constexpr uint32_t KPerBlock = KPerBlock_; + static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_; + static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_; + + //-------------------------------------- + // pass to device + mutable uint32_t sk_num_blocks; + uint32_t sk_num_big_blocks; + uint32_t dp_start_block_idx; + uint32_t reduction_start_block_idx; + uint32_t k_iters_per_big_block; + MDiv2 n_tiles; + MDiv k_iters_per_tile; + MDiv equiv_tiles_big; // for reduction + MDiv equiv_tiles_little; // for reduction + + // prefer construct on host + __host__ __device__ BlockToCTileMap_GemmStreamK_v2( + uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1) + { + // total output tiles + uint32_t num_tiles = + math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock); + k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock)); + + uint32_t dp_tiles, dp_num_blocks, sk_total_iters; + + // default to regular DP GEMM if sk blocks == 0 + if(streamk_sel == 0) + { + sk_num_blocks = 0; + dp_tiles = num_tiles; + sk_num_big_blocks = 0; + k_iters_per_big_block = 0; + + dp_num_blocks = num_tiles; // all tile to be dp block + dp_start_block_idx = 0; + sk_total_iters = 0; // clear this tiles + } + // 2-tile sk + DP GEMM + else + { + + // check if there's enough work for DP+ stream-k + bool bigEnough = num_tiles > grid_size; + // select between stream-k strategies + uint32_t sk_tiles = 0; + if(streamk_sel == 1) // 1 tile stream-k + { + sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles; + } + else if(streamk_sel == 2) // 2-tile stream-k + { + sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles; + } + else if(streamk_sel == 3) // 3-tile stream-k + { + sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size) + : num_tiles; + } + else if(streamk_sel == 4) // 4-tile stream-k + { + sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size) + : num_tiles; + } + sk_num_blocks = sk_tiles; + // remaining tiles are DP tiles + dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0; + + sk_total_iters = k_iters_per_tile.get() * sk_tiles; + + // k_iters_per_sk_block is the floor of avg each ck block loop over tiles. + // we need to decide how many iters for each sk block + // let m = k_iters_per_sk_block + // some of the sk block (little) will cover m iters, some (big) will cover m+1 + // we have + // 1) l + b = sk_blocks + // 2) l * m + b * (m + 1) = sk_total_iters + // => (l + b) * m + b = sk_total_iters + // => sk_blocks * m + b = sk_total_iters + // => b = sk_total_iters - m * sk_blocks + // NOTE: big could be zero + uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; + sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; + k_iters_per_big_block = k_iters_per_sk_block + 1; + + dp_num_blocks = dp_tiles; + dp_start_block_idx = sk_num_blocks; + } + + n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock)); + // using multiple blocks for parallel reduction + reduction_start_block_idx = dp_start_block_idx + dp_num_blocks; + + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + { + uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get()); + uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get()); + equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get()); + equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get()); + } + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + __host__ __device__ uint32_t get_sk_total_iters() const + { + uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block + + (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1); + return sk_total_iters; + } + + __host__ __device__ uint32_t get_sk_tiles() const + { + // tiles for sk + uint32_t sk_total_iters = get_sk_total_iters(); + return k_iters_per_tile.div(sk_total_iters); + } + + __host__ __device__ index_t get_grid_dims() const + { + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + { + // return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1); + return reduction_start_block_idx + get_sk_tiles(); + } + else + return reduction_start_block_idx; + } + + __device__ uint32_t get_block_idx() const + { + // TODO: swizzle block index for better locality + return __builtin_amdgcn_readfirstlane(blockIdx.x); + } + + __device__ void + get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const + { + if(block_idx < sk_num_big_blocks) + { + iter_start = block_idx * k_iters_per_big_block; + iter_end = iter_start + k_iters_per_big_block; + } + else if(block_idx < sk_num_blocks) + { + iter_start = (sk_num_big_blocks * k_iters_per_big_block) + + (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1); + iter_end = iter_start + (k_iters_per_big_block - 1); + } + else if(block_idx >= dp_start_block_idx) + { + uint32_t sk_total_iters = get_sk_total_iters(); + uint32_t dp_iters_per_block = k_iters_per_tile.get(); + iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block; + iter_end = iter_start + dp_iters_per_block; + } + } + + __device__ uint32_t get_current_iter_length(uint32_t iter_start, + uint32_t iter_end, + uint32_t total_iter_length) const + { + uint32_t iter_length_mod, iter_length_quo /*unused*/; + k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod); + uint32_t current_iter_length = math::min( + iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length); + return current_iter_length; + } + + __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); } + + __device__ void + get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const + { + k_iters_per_tile.divmod(iter, tile_idx, iter_offset); + } + + __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const + { + uint32_t m_tile_idx, n_tile_idx; + uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock); + n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx); + + // // swizzle tile + uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock); + + uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m; + + const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem)) + ? tile_swizzle_sub_m + : tile_swizzle_sub_m_rem; + + uint32_t m_tile_idx_sub0, m_tile_idx_sub1; + m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m; + m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m; + + uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value; + + uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt; + + n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt; + m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt; + return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m, + n_tile_idx_with_adapt); + } + + __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const + { + static constexpr uint32_t alignment = 128; + uint32_t acc_buffer_bytes = + MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes; + return (acc_buffer_bytes + alignment - 1) / alignment * alignment; + } + + __host__ __device__ uint32_t get_workspace_size_for_semaphore() const + { + return get_sk_tiles() * sizeof(uint32_t); + } + + __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const + { + return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore(); + } + + __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_, + const MDiv& equiv_tiles_) const + { + uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1); + uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1; + uint32_t quo_, rem_; + equiv_tiles_.divmod(tile_idx_, quo_, rem_); + return quo_ * max_equiv_tiles_ + rem_; + } + + __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_, + uint32_t iters_per_sk_block_) const + { + return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() - + 1); + } + + __host__ __device__ uint32_t get_total_acc_buffers() const + { + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + uint32_t tiles_cover_little_blocks = + get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1); + + uint32_t total_intersec_big = + get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big); + uint32_t total_intersec_little = + get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little); + + return sk_num_blocks + total_intersec_big + total_intersec_little; + } + + __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const + { + // TODO: from big to little + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + if(tile_idx_ < tiles_cover_big_blocks) + { + uint32_t touched_sk_blocks = + (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) / + k_iters_per_big_block; + uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big); + return touched_sk_blocks + current_intersec; + } + else + { + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_; + uint32_t touched_sk_blocks = + (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) / + iters_per_little_sk_block; + uint32_t current_intersec = + get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little); + return get_total_acc_buffers() - (touched_sk_blocks + current_intersec); + } + } + + __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const + { + uint32_t iters_per_big_sk_block = k_iters_per_big_block; + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + if(block_idx_ < sk_num_big_blocks) + { + uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block + + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big); + return block_idx_ + current_intersec; + } + else + { + uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_; + uint32_t touched_tiles = k_iters_per_tile.div( + block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little); + return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec); + } + } +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp new file mode 100644 index 0000000000..ff10215353 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -0,0 +1,2010 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_streamk_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t Streamk_sel_, + index_t Grid_size_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + Streamk_sel{Streamk_sel_}, + Grid_size{Grid_size_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, 1)}, + KPadded{CalculateKPadded(K_, 1)}, + AK0{CalculateAK0Padded(K_, 1)}, + BK0{CalculateBK0Padded(K_, 1)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << ", Stream-K Selection:" << Streamk_sel + << ", Grid size:" << Grid_size << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t Streamk_sel; + mutable index_t Grid_size; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t Streamk_sel_, + index_t Grid_size_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + + { + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Problem& problem, unsigned int kbatch_id, unsigned int orig_K) + { + if constexpr(is_same_v) + { + a_k_split_offset = kbatch_id * problem.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = kbatch_id * problem.KRead * problem.M; + } + + if constexpr(is_same_v) + { + b_k_split_offset = kbatch_id * problem.KRead * problem.N; + } + else if constexpr(is_same_v) + { + b_k_split_offset = kbatch_id * problem.KRead; + } + + if(kbatch_id < static_cast(problem.KBatch - 1)) + { + problem.K = problem.KRead; + } + else + { + problem.K = orig_K - problem.KRead * (problem.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(BDataType); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + + if(karg.K <= 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(is_same, bhalf_t>::value) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + Problem& problem) + { + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M, + problem.N, + AK0Number * problem.KPadded, + problem.Grid_size, + problem.Streamk_sel); + uint32_t iter_start, iter_end; + bool is_sk_block, is_dp_block; + index_t num_k_block_main_loop; + + for(auto block_idx = get_block_1d_id(); + block_idx < block_2_ctile_map_streamk.get_grid_dims(); + block_idx += gridDim.x) + { + + is_sk_block = + static_cast(block_idx) < block_2_ctile_map_streamk.sk_num_blocks; + is_dp_block = + static_cast(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx && + static_cast(block_idx) < + block_2_ctile_map_streamk.reduction_start_block_idx; + + block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); + num_k_block_main_loop = iter_end - iter_start; + + while(true) + { + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + block_2_ctile_map_streamk.get_current_iter_length( + iter_start, iter_end, num_k_block_main_loop)); + uint32_t tile_idx, iter_offset; + block_2_ctile_map_streamk.get_tile_idx_with_offset( + iter_end - 1, tile_idx, iter_offset); + iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); + + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K, + problem.KPadded, + problem.N, + problem.NPadded, + problem.StrideB, + problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto block_work_idx = + block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + const index_t k0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(iter_offset * AK0Number); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = + make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = + make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + // CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * + NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = SpaceFillingCurve< + Sequence<1, MPerBlock, 1, NPerBlock>, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(0, 0, 0, 0)); + + if(is_dp_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + else if(is_sk_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // exit condition + iter_end -= current_iter_length; + if(iter_end <= iter_start) + break; + // make sure next loop LDS is ready for use + block_sync_lds(); + } + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + Problem& problem) + { + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + Block2CTileMap_streamk block_2_ctile_map_streamk( + problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size); + uint32_t iter_start, iter_end; + bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block; + index_t num_k_block_main_loop; + + for(auto block_idx = get_block_1d_id(); + block_idx < block_2_ctile_map_streamk.get_grid_dims(); + block_idx += gridDim.x) + { + is_sk_block = + static_cast(block_idx) < block_2_ctile_map_streamk.sk_num_blocks; + is_dp_block = + static_cast(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx && + static_cast(block_idx) < + block_2_ctile_map_streamk.reduction_start_block_idx; + + block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); + num_k_block_main_loop = iter_end - iter_start; + + { + + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + block_2_ctile_map_streamk.get_current_iter_length( + iter_start, iter_end, num_k_block_main_loop)); + uint32_t tile_idx, iter_offset; + block_2_ctile_map_streamk.get_tile_idx_with_offset( + iter_end - 1, tile_idx, iter_offset); + iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); + + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K, + problem.KPadded, + problem.N, + problem.NPadded, + problem.StrideB, + problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto block_work_idx = + block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + const index_t k0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(iter_offset * AK0Number); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = + make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = + make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + // CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * + NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = SpaceFillingCurve< + Sequence<1, MPerBlock, 1, NPerBlock>, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(0, 0, 0, 0)); + + if(is_dp_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + else if(is_sk_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + } + } +}; + +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp new file mode 100644 index 0000000000..19fa6c209f --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp @@ -0,0 +1,337 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +#ifdef CK_ENABLE_FP16 +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances); +#endif +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceGemm_Streamk_V2; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + op_ptrs); + } + } +#endif + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt new file mode 100644 index 0000000000..2a930ab9ae --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt @@ -0,0 +1,26 @@ +# ONLY XDL_KERNELS +set(GEMM_UNIVERSAL_STREAMK_INSTANCES) + +list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp) + +add_instance_library(device_gemm_universal_streamk_instance ${GEMM_UNIVERSAL_STREAMK_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp new file mode 100644 index 0000000000..6e8d5c798b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 2, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 4, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 4, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..6adcb8f4f4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..631ae6872f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 0000000000..2c49773a65 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp new file mode 100644 index 0000000000..39d54fb885 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..8ee50d63cb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..d31e0819a4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 0000000000..fe19f35e53 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..6c1873b373 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 0000000000..ffd53f4069 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 0000000000..094b8f92f8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..e00c1733e0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // AGPR Spill + // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + // AGPR Spill when use permuted lds layout. so, use padding for these two. + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..546f909b3c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..d91de96be3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 0000000000..c70678b449 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp new file mode 100644 index 0000000000..5410a0cc25 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..4ae7329f98 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..4fc5458a96 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 0000000000..7369f87a57 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..45425a41a1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 0000000000..3b5ac0366f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 0000000000..53aa011a75 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/CMakeLists.txt new file mode 100644 index 0000000000..2a930ab9ae --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/CMakeLists.txt @@ -0,0 +1,26 @@ +# ONLY XDL_KERNELS +set(GEMM_UNIVERSAL_STREAMK_INSTANCES) + +list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp) + +add_instance_library(device_gemm_universal_streamk_instance ${GEMM_UNIVERSAL_STREAMK_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp new file mode 100644 index 0000000000..6e8d5c798b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 2, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 4, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 4, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..6adcb8f4f4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..631ae6872f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 0000000000..2c49773a65 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp new file mode 100644 index 0000000000..39d54fb885 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..8ee50d63cb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..d31e0819a4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 0000000000..fe19f35e53 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..6c1873b373 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 0000000000..ffd53f4069 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 0000000000..094b8f92f8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..e00c1733e0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // AGPR Spill + // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + // AGPR Spill when use permuted lds layout. so, use padding for these two. + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..546f909b3c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..d91de96be3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 0000000000..c70678b449 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp new file mode 100644 index 0000000000..5410a0cc25 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..4ae7329f98 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..4fc5458a96 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 0000000000..7369f87a57 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..45425a41a1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 0000000000..3b5ac0366f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 0000000000..53aa011a75 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp new file mode 100644 index 0000000000..72194e8e61 --- /dev/null +++ b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp @@ -0,0 +1,332 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_universal_streamk_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int Streamk_sel, + int Grid_size, + int n_warmup, + int n_iter, + uint64_t rotating = 0) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + int total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes(); + int rotating_count = std::max( + 1, + std::min(n_iter, + static_cast(std::ceil(static_cast(rotating) / total_gemm_needed)))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + std::cout << "rotating count: " << rotating_count << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + using DeviceOp = ck::tensor_operation::device::DeviceGemm_Streamk_V2; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference GEMM + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float best_grid_size = 0; + float best_streamk_sel = 0; + + // profile device GEMM instances + for(auto& op_ptr : op_ptrs) + { + std::vector grid_size_list = {38, 76, 114, 152, 190, 228, 266, 304, 342, 380}; + std::vector streamk_sel_list = { + 0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K+ DP, + // 2:2-tile Stream-K + DP + + if(Grid_size == -1) + { + grid_size_list = {Grid_size}; + } + if(Streamk_sel != -1) + { + streamk_sel_list = {Streamk_sel}; + } + for(std::size_t j = 0; j < streamk_sel_list.size(); j++) + { + for(std::size_t i = 0; i < grid_size_list.size(); i++) + { + auto grid_size_curr = grid_size_list[i]; + index_t streamk_sel_curr = streamk_sel_list[j]; + printf("streamk_sel_curr=%0d\n", streamk_sel_curr); + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + streamk_sel_curr, + grid_size_curr, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, + time_kernel, + 0, + n_warmup, + n_iter, + rotating_count > 1, + rotating_count}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", Grid_size " + << grid_size_curr << ", streamk selection strategy" + << streamk_sel_curr << std::endl; + +#if defined CK_ENABLE_FP8 + // set softer tolerances for fp8 + if constexpr(is_same_v || is_same_v || + is_same_v) + { + std::string msg = "Error: Incorrect results!"; + double rtol = 1e-1; + double atol = 1e-1; + pass = pass & ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, msg, rtol, atol); + } + else + { +#endif + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#if defined CK_ENABLE_FP8 + } +#endif + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_grid_size = grid_size_curr; + best_streamk_sel = streamk_sel_curr; + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC + << " Grid_size = " << best_grid_size + << " Stream-K selection strategy = " << best_streamk_sel << " : " << best_ave_time + << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt old mode 100644 new mode 100755 index 5262ca33a6..c2a9769727 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -52,6 +52,7 @@ if(GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_universal_streamk.cpp) list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp) list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) @@ -120,6 +121,7 @@ if(GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_streamk_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) diff --git a/profiler/src/profile_gemm_universal_streamk.cpp b/profiler/src/profile_gemm_universal_streamk.cpp new file mode 100644 index 0000000000..cd3f5787d6 --- /dev/null +++ b/profiler/src/profile_gemm_universal_streamk.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_universal_streamk_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F16_F16, // 4 + F16_F8_F16, // 5 + F16_F16_F16_F8, // 6 +}; + +#define OP_NAME "gemm_universal_streamk" +#define OP_DESC "Universal Streamk GEMM" + +int profile_gemm_universal_streamk(int argc, char* argv[]) +{ + if(argc != 16 && argc != 19) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, " + "comp f8)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: Stream-k select strategy 0: all DP, 1: 1-tile SK, 2: 2-tile SK\n"); + printf("arg15: Grid-size, -1 for max persistent kernel occupancy\n"); + printf("optional:\n"); + printf("arg16: number of warm-up cycles (default 1)\n"); + printf("arg17: number of iterations (default 10)\n"); + printf("arg18: memory for rotating buffer (default 0, size in MB)\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + const int Streamk_sel = std::stoi(argv[14]); + const int Grid_size = std::stoi(argv[15]); + + int n_warmup = 20; + int n_iter = 50; + uint64_t rotating = 0; + if(argc == 19) + { + n_warmup = std::stoi(argv[16]); + n_iter = std::stoi(argv[17]); + rotating = std::stoull(argv[18]) * 1024 * 1024; + } + + using F32 = float; + using F16 = ck::half_t; + // using BF16 = ck::bhalf_t; + // using F8 = ck::f8_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto c_type, + auto a_layout, + auto b_layout, + auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_universal_streamk_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideC < 0) ? DefaultStrideC : StrideC, + Streamk_sel, + Grid_size, + n_warmup, + n_iter, + rotating); + + return pass ? 0 : 1; + }; + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_universal_streamk); diff --git a/script/check_copyright_year.sh b/script/check_copyright_year.sh old mode 100755 new mode 100644 From efd18fa8877db5ec2cfc22d17f48f8e213696213 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 8 Jul 2024 06:27:44 +0000 Subject: [PATCH 15/16] Conditionally add fwd_splitkv API in fmha_fwd example --- example/ck_tile/01_fmha/CMakeLists.txt | 33 +++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index e30e9e793c..241213a6ec 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -1,7 +1,27 @@ -# generate a list of kernels, but not actually emit files at config stage +# validate user-specified fmha_fwd API list +set(EXAMPLE_FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv") +set(EXAMPLE_FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING + "semicolon-separated list of APIs to generate (${EXAMPLE_FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") +if(EXAMPLE_FMHA_FWD_ENABLE_APIS STREQUAL "all") + set(EXAMPLE_FMHA_FWD_ENABLE_APIS ${EXAMPLE_FMHA_FWD_KNOWN_APIS}) +endif() + +foreach(api ${EXAMPLE_FMHA_FWD_ENABLE_APIS}) + if(NOT "${api}" IN_LIST EXAMPLE_FMHA_FWD_KNOWN_APIS) + message(FATAL_ERROR "${api} isn't a known api: ${EXAMPLE_FMHA_FWD_KNOWN_APIS}.") + endif() +endforeach() + +# "fwd" is a must-have api for the fmha_fwd example, add it if not specified +if(NOT "fwd" IN_LIST EXAMPLE_FMHA_FWD_ENABLE_APIS) + list(APPEND EXAMPLE_FMHA_FWD_ENABLE_APIS "fwd") +endif() + +string(REPLACE ";" "," EXAMPLE_FMHA_FWD_APIS "${EXAMPLE_FMHA_FWD_ENABLE_APIS}") +# generate a list of kernels, but not actually emit files at config sta execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt + --api ${EXAMPLE_FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt ) execute_process( @@ -17,7 +37,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) add_custom_command( OUTPUT ${FMHA_FWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --api ${EXAMPLE_FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} ) add_custom_command( @@ -61,6 +81,13 @@ else() list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) endif() +# conditionally enable call to the fwd_splitkv API in fmha_fwd example +if ("fwd_splitkv" IN_LIST EXAMPLE_FMHA_FWD_ENABLE_APIS) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1) +else() + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0) +endif() + # Allow comparing floating points directly in order to check sentinel values list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) From 82f3b3d0a0aea8a8e5c9feafdbe081eabd7e81b6 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 8 Jul 2024 06:40:18 +0000 Subject: [PATCH 16/16] Conditionally add call to fmha_fwd_splitkv() --- example/ck_tile/01_fmha/fmha_fwd.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 28f7905734..7553ed0aaa 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -248,11 +248,13 @@ float fmha_fwd_dispatch(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) { +#if CK_TILE_FMHA_FWD_SPLITKV_API if(1 < args.num_splits) { return fmha_fwd_splitkv(traits, args, config); } else +#endif { return fmha_fwd(traits, args, config); }