[rocm-libraries] ROCm/rocm-libraries#4871 (commit 7d4c040)

[CK] Decouple EpilogueArgs from GridwiseGemm implementation (#4871)

This is a duplicate of #4537. I could not re-open it since the target
branch got deleted and could not change the target branch since it was
closed... :)

## Motivation

Right now, all the Epilogue structs are declared inside the base
gridwise struct. They should be independent of it, and the specialization
of the selected Epilogue Type should be declared within the kernel
function.

## Technical Details

All Epilogue structs depend on template parameters that are known to the
base Gridwise Gemm struct. In this PR, we export them to be used
independently by any struct that might need to extract them. This
approach will serve the decoupling purposes for the Epilogues, but also
enable future constructs to use and expand this approach. See
30e2a4c01b64bdea68857c7badd9d7cffbf1adb9.

One issue that currently arises is that, when implementing a new Epilogue
Type, the developer is not forced to decide where this struct may or may
not be used. To fix this, I propose defining an
`enum struct EpilogueType` that will be used to fetch the Epilogue
specialization through a helper struct. See
a943ac8d130e12d6843715b322181186e54ba15c.
Note that all the instantiation details will stay in this helper struct.
Also note the static assertion in the else statement.

## Test Plan

Test with existing CI, as nothing is added/removed.

## Test Result

All relevant existing CI tests should pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
This commit is contained in:
chris-tsiaousis-hpc
2026-05-22 20:39:01 +02:00
committed by GitHub
parent b1975951d4
commit c7fac341de
21 changed files with 510 additions and 371 deletions

View File

@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -57,12 +58,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_batch(i) = karg.p_ds_grid_[i] + ds_batch_offset[i]; });
using EpilogueType = typename std::conditional<GridwiseOp::IsBWaveTransferApplicable &&
GridwiseOp::UseDirectStore,
typename GridwiseOp::EpilogueDirectStore,
typename GridwiseOp::EpilogueCShuffle>::type;
constexpr auto epilogue_type =
GridwiseOp::IsBWaveTransferApplicable && GridwiseOp::UseDirectStore
? EpilogueType::DirectStore
: EpilogueType::CShuffle;
using SelectedEpilogue = get_epilogue_t<epilogue_type, GridwiseOp>;
constexpr index_t LDS_size = GridwiseOp::template GetSharedMemoryNumberOfByte<EpilogueType>();
constexpr index_t LDS_size =
GridwiseOp::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ char p_shared[LDS_size];
const auto a_grid_desc_ak0_m_ak1 =
@@ -70,7 +73,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const auto b_grid_desc_bk0_n_bk1 =
GridwiseOp::MakeBGridDescriptor_BK0_N_BK1(karg.b_grid_desc_n_k_);
auto epilogue_args = EpilogueType{};
auto epilogue_args = SelectedEpilogue{};
GridwiseOp::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set, TailNum>(
p_as_grid_batch,
p_bs_grid_batch,

View File

@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -61,8 +62,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const long_index_t c_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
using SelectedEpilogue = get_epilogue_t<EpilogueType::CShuffle, GridwiseGemm>;
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ char p_shared[LDS_size];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
@@ -81,7 +84,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i];
});
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
auto epilogue_args = SelectedEpilogue{};
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg, epilogue_args);

View File

@@ -12,6 +12,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -50,9 +52,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle<ReduceTrait>;
using SelectedEpilogue =
get_epilogue_t<EpilogueType::ReduceCShuffle, GridwiseGemm, ReduceTrait>;
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ char p_shared[LDS_size];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
@@ -85,11 +89,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
});
auto epilogue_args = EpilogueType(reduces_batch,
reduce_in_element_ops,
reduce_out_element_ops,
karg.M,
tensor_operation::element_wise::PassThrough{});
auto epilogue_args = SelectedEpilogue(reduces_batch,
reduce_in_element_ops,
reduce_out_element_ops,
karg.M,
tensor_operation::element_wise::PassThrough{});
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_as_grid_shift,

View File

@@ -8,7 +8,7 @@
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common_kernels.hpp"
#include <optional>
#include <type_traits>

View File

@@ -12,6 +12,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -47,14 +49,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle<ReduceTrait>;
using SelectedEpilogue =
get_epilogue_t<EpilogueType::ReduceCShuffle, GridwiseGemm, ReduceTrait>;
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ char p_shared[LDS_size];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
auto epilogue_args = EpilogueType(
auto epilogue_args = SelectedEpilogue(
p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M, d0_element_op);
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(

View File

@@ -13,6 +13,7 @@
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -48,14 +49,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueWelfordCShuffle>();
using SelectedEpilogue = get_epilogue_t<EpilogueType::WelfordCShuffle, GridwiseGemm>;
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ char p_shared[LDS_size];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
auto epilogue_args = typename GridwiseGemm::EpilogueWelfordCShuffle(
auto epilogue_args = SelectedEpilogue(
p_welford_mean_grid, p_welford_var_grid, p_welford_count_grid, karg.M, karg.N);
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
@@ -298,14 +300,16 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
return PadTensorDescriptor(grid_desc_x, make_tuple(XPerTile), Sequence<true>{});
}
using SelectedEpilogue = get_epilogue_t<EpilogueType::WelfordCShuffle, GridwiseGemmWelford>;
using LayernormMeanVarGridDesc_M_NBlock =
decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N<
decltype(SelectedEpilogue::template MakeMeanVarDescriptor_M_N<
Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(1, 1));
using LayernormCountGridDesc_M_NBlock =
decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N<
decltype(SelectedEpilogue::template MakeCountDescriptor_M_N<
Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(1, 1));
@@ -398,13 +402,13 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
static_for<0, NumDTensor, 1>{}([&](auto i) { p_ds_grid_[i] = p_ds_grid[i]; });
layernorm_mean_var_grid_desc_m_nblock_ =
GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N<
SelectedEpilogue::template MakeMeanVarDescriptor_M_N<
Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_);
layernorm_count_grid_desc_m_nblock_ =
GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N<
SelectedEpilogue::template MakeCountDescriptor_M_N<
Sequence<true, true>,
LayernormBlockTileSize_M_N::At(0),
LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_);

View File

@@ -12,6 +12,8 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -46,18 +48,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle<ReduceTrait>;
using SelectedEpilogue =
get_epilogue_t<EpilogueType::ReduceCShuffle, GridwiseGemm, ReduceTrait>;
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ char p_shared[LDS_size];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
auto epilogue_args = EpilogueType(p_reduces_grid,
reduce_in_element_ops,
reduce_out_element_ops,
karg.M,
tensor_operation::element_wise::PassThrough{});
auto epilogue_args = SelectedEpilogue(p_reduces_grid,
reduce_in_element_ops,
reduce_out_element_ops,
karg.M,
tensor_operation::element_wise::PassThrough{});
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg, epilogue_args);

View File

@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common_kernels.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"

View File

@@ -19,6 +19,7 @@
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
@@ -72,9 +73,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
if constexpr(EGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd)
{
#endif
__shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>()];
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
using SelectedEpilogue = get_epilogue_t<EpilogueType::CShuffle, GridwiseGemm>;
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ char p_shared[LDS_size];
auto epilogue_args = SelectedEpilogue{};
const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
index_t left = 0;

View File

@@ -18,6 +18,7 @@
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
@@ -68,13 +69,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd)
{
#endif
using SelectedEpilogue = get_epilogue_t<EpilogueType::CShuffle, GridwiseGemm>;
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ char p_shared[LDS_size];
const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
auto epilogue_args = SelectedEpilogue{};
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_WEIGHT,
AGridDesc_AK0_M_K1,

View File

@@ -20,6 +20,7 @@
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
@@ -69,12 +70,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd)
{
#endif
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
__shared__ char p_shared[LDS_size];
using SelectedEpilogue = get_epilogue_t<EpilogueType::CShuffle, GridwiseGemm>;
const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
auto epilogue_args = SelectedEpilogue{};
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ char p_shared[LDS_size];
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_WEIGHT,
AGridDesc_AK0_M_K1,

View File

@@ -19,6 +19,7 @@
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
@@ -68,12 +69,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd)
{
#endif
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
__shared__ char p_shared[LDS_size];
using SelectedEpilogue = get_epilogue_t<EpilogueType::CShuffle, GridwiseGemm>;
const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
auto epilogue_args = SelectedEpilogue{};
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ char p_shared[LDS_size];
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_WEIGHT,
AGridDesc_AK0_M_K1,

View File

@@ -21,6 +21,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
@@ -90,17 +91,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
using EpilogueType =
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
GridwiseGemm::UseDirectStore,
typename GridwiseGemm::EpilogueDirectStore,
typename GridwiseGemm::EpilogueCShuffle>::type;
constexpr auto epilogue_type =
GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore
? EpilogueType::DirectStore
: EpilogueType::CShuffle;
using SelectedEpilogue = get_epilogue_t<epilogue_type, GridwiseGemm>;
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ char p_shared[LDS_size];
auto epilogue_args = EpilogueType{};
auto epilogue_args = SelectedEpilogue{};
const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);

View File

@@ -20,6 +20,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -50,8 +51,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const ComputePtrOffset compute_ptr_offset_of_n)
{
#if defined(__gfx11__) || defined(__gfx12__)
using Epilogue = typename GridwiseGemm::EpilogueCShuffle;
__shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte<Epilogue>()];
using SelectedEpilogue = get_epilogue_t<EpilogueType::CShuffle, GridwiseGemm>;
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ char p_shared[LDS_size];
const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
@@ -147,7 +151,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const index_t num_k_block_per_scale = GridwiseGemm::GetKBlockPerScale();
auto epilogue_args = Epilogue{};
auto epilogue_args = SelectedEpilogue{};
GridwiseGemm::Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
@@ -155,7 +159,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
decltype(e_grid_desc),
decltype(a_scale_struct),
decltype(b_scale_struct),
Epilogue,
SelectedEpilogue,
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_as_grid_,

View File

@@ -22,6 +22,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -57,12 +58,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const CDEElementwiseOperation cde_element_op)
{
#if defined(__gfx11__) || defined(__gfx12__)
using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
GridwiseGemm::UseDirectStore,
typename GridwiseGemm::EpilogueDirectStore,
typename GridwiseGemm::EpilogueCShuffle>::type;
constexpr auto epilogue_type =
GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore
? EpilogueType::DirectStore
: EpilogueType::CShuffle;
using SelectedEpilogue = get_epilogue_t<epilogue_type, GridwiseGemm>;
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ char p_shared[LDS_size];
const index_t KBatch = 1;
@@ -139,13 +142,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const auto block_2_etile_map =
GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);
auto epilogue_args = EpilogueType{};
auto epilogue_args = SelectedEpilogue{};
GridwiseGemm::template Run<HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum,
decltype(block_2_etile_map),
EpilogueType,
SelectedEpilogue,
1,
2>(
p_as_grid_,

View File

@@ -17,6 +17,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
@@ -66,12 +67,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const CDEElementwiseOperation cde_element_op)
{
#if(defined(__gfx11__) || defined(__gfx12__))
using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
GridwiseGemm::UseDirectStore,
typename GridwiseGemm::EpilogueDirectStore,
typename GridwiseGemm::EpilogueCShuffle>::type;
constexpr auto epilogue_type =
GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore
? EpilogueType::DirectStore
: EpilogueType::CShuffle;
using SelectedEpilogue = get_epilogue_t<epilogue_type, GridwiseGemm>;
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ uint8_t p_shared[LDS_size];
const auto gemm_desc_ptr =
@@ -154,7 +157,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
gemm_desc_ptr[group_id].StrideE,
1);
auto epilogue_args = EpilogueType{};
auto epilogue_args = SelectedEpilogue{};
constexpr TailNumber TailNum = TailNumber::Full;
if(has_main_k_block_loop)

View File

@@ -22,6 +22,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
namespace ck {
namespace tensor_operation {
@@ -94,12 +95,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
{
#if(defined(__gfx11__) || defined(__gfx12__))
using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
GridwiseGemm::UseDirectStore,
typename GridwiseGemm::EpilogueDirectStore,
typename GridwiseGemm::EpilogueCShuffle>::type;
constexpr auto epilogue_type =
GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore
? EpilogueType::DirectStore
: EpilogueType::CShuffle;
using SelectedEpilogue = get_epilogue_t<epilogue_type, GridwiseGemm>;
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ char p_shared[LDS_size];
const index_t block_id = get_block_1d_id();
@@ -179,13 +182,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(kernel_arg, tile_index[Number<0>{}]);
auto epilogue_args = EpilogueType{};
auto epilogue_args = SelectedEpilogue{};
GridwiseGemm::template Run<HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
GroupedGemmBlock2ETileMap,
EpilogueType,
SelectedEpilogue,
1, // Block2CTileMap MBlock index
2 // Block2CTileMap NBlock index
>(static_cast<void*>(p_shared),

View File

@@ -19,6 +19,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -41,12 +42,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const index_t group_count)
{
#if(defined(__gfx11__) || defined(__gfx12__))
using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
GridwiseGemm::UseDirectStore,
typename GridwiseGemm::EpilogueDirectStore,
typename GridwiseGemm::EpilogueCShuffle>::type;
constexpr auto epilogue_type =
GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore
? EpilogueType::DirectStore
: EpilogueType::CShuffle;
using SelectedEpilogue = get_epilogue_t<epilogue_type, GridwiseGemm>;
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
__shared__ char p_shared[LDS_size];
const index_t block_id = get_block_1d_id();
@@ -93,13 +96,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]);
auto epilogue_args = EpilogueType{};
auto epilogue_args = SelectedEpilogue{};
GridwiseGemm::template Run<HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
Block2CTileMap,
EpilogueType,
SelectedEpilogue,
1, // Block2CTileMap MBlock index
2 // Block2CTileMap NBlock index
>(static_cast<void*>(p_shared),

View File

@@ -0,0 +1,125 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp"
namespace ck {
/// Selector for the epilogue implementation that `get_epilogue` resolves to a concrete
/// struct. The numeric values are spelled out so they stay fixed if entries are reordered.
enum class EpilogueType
{
    CShuffle        = 0, ///< resolved to `EpilogueCShuffle` by `get_epilogue`
    DirectStore     = 1, ///< resolved to `EpilogueDirectStore` by `get_epilogue`
    WelfordCShuffle = 2, ///< resolved to `EpilogueWelfordCShuffle` by `get_epilogue`
    ReduceCShuffle  = 3  ///< resolved to `EpilogueReduceCShuffle`; requires a ReduceTrait
};
/// Resolves an `EpilogueType` enumerator to the concrete epilogue struct, instantiated with
/// the template parameters exported through `GridwiseGemm::Traits`.
///
/// @tparam type         Epilogue to select.
/// @tparam GridwiseGemm Type exposing a nested `Traits` with the `*_` members read below.
/// @tparam ReduceTrait  Extra trait forwarded to `EpilogueReduceCShuffle`. It must differ from
///                      the default `Tuple<>` if and only if `type` is `ReduceCShuffle`
///                      (enforced by a static_assert).
///
/// The result is exposed as the nested alias `Type`.
template <EpilogueType type, typename GridwiseGemm, typename ReduceTrait = Tuple<>>
struct get_epilogue
{
    private:
    // Always false, but dependent on a template argument. A plain `static_assert(false, ...)`
    // inside a discarded `if constexpr` branch is only well-formed since CWG2518 and is
    // rejected by older compilers; a dependent condition is evaluated at instantiation only.
    template <EpilogueType>
    static constexpr bool always_false_v = false;

    // Used solely inside `decltype` below: only its deduced return type matters.
    static constexpr auto get_epilogue_implementation()
    {
        static_assert((type == EpilogueType::ReduceCShuffle) ==
                          (!std::is_same_v<ReduceTrait, Tuple<>>),
                      "Provide a ReduceTrait only if the desired epilogue type is ReduceCShuffle.");

        using TypeExtractor = typename GridwiseGemm::Traits;

        if constexpr(type == EpilogueType::CShuffle)
        {
            return EpilogueCShuffle<
                typename TypeExtractor::DsDataType_,
                typename TypeExtractor::EDataType_,
                typename TypeExtractor::AccDataType_,
                typename TypeExtractor::CShuffleDataType_,
                TypeExtractor::MPerBlock_,
                TypeExtractor::NPerBlock_,
                TypeExtractor::MPerWmma_,
                TypeExtractor::NPerWmma_,
                TypeExtractor::MRepeat_,
                TypeExtractor::NRepeat_,
                TypeExtractor::CShuffleMRepeatPerShuffle_,
                TypeExtractor::CShuffleNRepeatPerShuffle_,
                typename TypeExtractor::
                    CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
                typename TypeExtractor::CDEShuffleBlockTransferScalarPerVectors_,
                typename TypeExtractor::CDEElementwiseOperation_,
                typename TypeExtractor::ThisThreadBlock_,
                typename TypeExtractor::BlockwiseGemmPipe_>{};
        }
        else if constexpr(type == EpilogueType::DirectStore)
        {
            return EpilogueDirectStore<typename TypeExtractor::DsDataType_,
                                       typename TypeExtractor::EDataType_,
                                       typename TypeExtractor::AccDataType_,
                                       TypeExtractor::MRepeat_,
                                       TypeExtractor::NRepeat_,
                                       typename TypeExtractor::CDEElementwiseOperation_,
                                       typename TypeExtractor::BlockwiseGemmPipe_>{};
        }
        else if constexpr(type == EpilogueType::WelfordCShuffle)
        {
            // The five empty initializers only serve to form a valid return expression.
            return EpilogueWelfordCShuffle<
                typename TypeExtractor::DsDataType_,
                typename TypeExtractor::EDataType_,
                typename TypeExtractor::AccDataType_,
                typename TypeExtractor::CShuffleDataType_,
                TypeExtractor::MPerBlock_,
                TypeExtractor::NPerBlock_,
                TypeExtractor::MPerWmma_,
                TypeExtractor::NPerWmma_,
                TypeExtractor::MRepeat_,
                TypeExtractor::NRepeat_,
                TypeExtractor::CShuffleMRepeatPerShuffle_,
                TypeExtractor::CShuffleNRepeatPerShuffle_,
                typename TypeExtractor::
                    CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
                typename TypeExtractor::CDEShuffleBlockTransferScalarPerVectors_,
                typename TypeExtractor::CDEElementwiseOperation_,
                typename TypeExtractor::ThisThreadBlock_,
                typename TypeExtractor::BlockwiseGemmPipe_,
                TypeExtractor::BlockSize_>{{}, {}, {}, {}, {}};
        }
        else if constexpr(type == EpilogueType::ReduceCShuffle)
        {
            // The five empty initializers only serve to form a valid return expression.
            return EpilogueReduceCShuffle<
                typename TypeExtractor::DsDataType_,
                typename TypeExtractor::EDataType_,
                typename TypeExtractor::AccDataType_,
                typename TypeExtractor::CShuffleDataType_,
                TypeExtractor::MPerBlock_,
                TypeExtractor::NPerBlock_,
                TypeExtractor::MPerWmma_,
                TypeExtractor::NPerWmma_,
                TypeExtractor::MRepeat_,
                TypeExtractor::NRepeat_,
                TypeExtractor::CShuffleMRepeatPerShuffle_,
                TypeExtractor::CShuffleNRepeatPerShuffle_,
                typename TypeExtractor::
                    CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
                typename TypeExtractor::CDEShuffleBlockTransferScalarPerVectors_,
                typename TypeExtractor::CDEElementwiseOperation_,
                typename TypeExtractor::ThisThreadBlock_,
                typename TypeExtractor::BlockwiseGemmPipe_,
                TypeExtractor::GemmSpec_,
                TypeExtractor::BlockSize_,
                ReduceTrait>{{}, {}, {}, {}, {}};
        }
        else
        {
            // Reached only for an enumerator without a branch above, so adding a new
            // EpilogueType value forces a decision here.
            static_assert(always_false_v<type>, "Not implemented for the specified type.");
        }
    }

    public:
    /// Concrete epilogue struct selected by `type`.
    using Type = decltype(get_epilogue_implementation());
};
// Shorthand for the epilogue struct selected by an EpilogueType value.
//
// Resolves to get_epilogue<...>::Type, i.e. the return type of that helper's
// get_epilogue_implementation(). The third argument defaults to an empty Tuple<> and is
// forwarded unchanged to the helper.
template <EpilogueType Kind, typename Gemm, typename ReduceTraitT = Tuple<>>
using get_epilogue_t = typename get_epilogue<Kind, Gemm, ReduceTraitT>::Type;
} // namespace ck

View File

@@ -19,10 +19,6 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp"
@@ -33,214 +29,6 @@
namespace ck {
// Kernel entry point for the WMMA CShuffle-v3 GEMM.
//
// Uses GridwiseGemm::EpilogueDirectStore when the gridwise GEMM reports both
// IsBWaveTransferApplicable and UseDirectStore, and GridwiseGemm::EpilogueCShuffle
// otherwise. The shared-memory buffer is sized for the chosen epilogue and everything is
// forwarded to GridwiseGemm::Run. blockIdx.z is handed to SplitKBatchOffset.
//
// When compiled for anything other than gfx11/gfx12 the body only marks `karg` as used.
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
    // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions, so the body is
    // discarded for AtomicAdd on half_t / bhalf_t outputs.
    using OutputElement = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
    constexpr bool is_atomic_add =
        EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd;
    constexpr bool is_half_or_bhalf_output =
        std::is_same_v<OutputElement, ck::half_t> || std::is_same_v<OutputElement, ck::bhalf_t>;
    if constexpr(!is_atomic_add || !is_half_or_bhalf_output)
    {
#endif
        // Epilogue choice is made entirely at compile time.
        constexpr bool use_direct_store =
            GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore;
        using ChosenEpilogue = std::conditional_t<use_direct_store,
                                                  typename GridwiseGemm::EpilogueDirectStore,
                                                  typename GridwiseGemm::EpilogueCShuffle>;

        // Shared-memory size depends on the epilogue that was picked.
        constexpr index_t lds_bytes =
            GridwiseGemm::template GetSharedMemoryNumberOfByte<ChosenEpilogue>();
        __shared__ char lds_buffer[lds_bytes];

        typename GridwiseGemm::SplitKBatchOffset splitk_offsets(karg, blockIdx.z);
        ChosenEpilogue epilogue{};

        GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
            lds_buffer, splitk_offsets, karg, epilogue);
#if defined(__gfx11__)
    }
#endif
#else
    ignore = karg;
#endif
}
// Kernel entry point for the batched WMMA CShuffle-v3 GEMM.
//
// blockIdx.y selects the batch and blockIdx.z is handed to SplitKBatchOffset. The A and B
// pointers are shifted by both the split-K offset and the per-batch offset before being
// passed to GridwiseGemm::Run; the E pointer is shifted by c_reduce_offset plus the
// per-batch C offset. With IsBScaled the Run overload taking scale pointers is used, and
// only the B scale pointer is shifted.
//
// When compiled for anything other than gfx11/gfx12 the body only marks its arguments as
// used.
template <typename GridwiseGemm,
          typename ComputePtrOffsetOfStridedBatch,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          bool IsBScaled           = false,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    kernel_batched_gemm_wmma_cshuffle_v3(
        // This works for now, but what is actually received is a
        // DeviceBatchedGemm_Wmma_CShuffleV3::Argument, through implicit conversion to its
        // base class!
        typename GridwiseGemm::Argument karg,
        const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
    // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions, so the body is
    // discarded for AtomicAdd on half_t / bhalf_t outputs.
    using OutputElement = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
    constexpr bool is_atomic_add =
        CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd;
    constexpr bool is_half_or_bhalf_output =
        std::is_same_v<OutputElement, ck::half_t> || std::is_same_v<OutputElement, ck::bhalf_t>;
    if constexpr(!is_atomic_add || !is_half_or_bhalf_output)
    {
#endif
        // Batching would normally stretch the grid Z dimension (the outermost one), but
        // that only works if lower-level code does not use Z for anything else. K batching
        // relies directly on blockIdx.z through SplitKBatchOffset, so for now the grid Y
        // dimension carries the batch index. This may be a bit fragile.
        const index_t batch_idx = amd_wave_read_first_lane(blockIdx.y);

        const long_index_t a_offset_in_batch =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(batch_idx));
        const long_index_t b_offset_in_batch =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(batch_idx));
        const long_index_t c_offset_in_batch =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(batch_idx));

        // Epilogue choice is made entirely at compile time.
        constexpr bool use_direct_store =
            GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore;
        using ChosenEpilogue = std::conditional_t<use_direct_store,
                                                  typename GridwiseGemm::EpilogueDirectStore,
                                                  typename GridwiseGemm::EpilogueCShuffle>;

        constexpr index_t lds_bytes =
            GridwiseGemm::template GetSharedMemoryNumberOfByte<ChosenEpilogue>();
        __shared__ char lds_buffer[lds_bytes];

        typename GridwiseGemm::SplitKBatchOffset splitk_offsets(karg, blockIdx.z);

        // Shift every A pointer by its split-K offset and by the batch offset.
        typename GridwiseGemm::AsGridPointer shifted_as;
        static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto tensor_idx) {
            using AElement = remove_cvref_t<
                tuple_element_t<tensor_idx.value, typename GridwiseGemm::AsDataType_>>;
            const auto* a_base = static_cast<const AElement*>(karg.p_as_grid[tensor_idx]);
            shifted_as(tensor_idx) =
                a_base + splitk_offsets.a_k_split_offset[tensor_idx] + a_offset_in_batch;
        });

        // Shift every B pointer by its split-K offset and by the batch offset.
        typename GridwiseGemm::BsGridPointer shifted_bs;
        static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto tensor_idx) {
            using BElement = remove_cvref_t<
                tuple_element_t<tensor_idx.value, typename GridwiseGemm::BsDataType_>>;
            const auto* b_base = static_cast<const BElement*>(karg.p_bs_grid[tensor_idx]);
            shifted_bs(tensor_idx) =
                b_base + splitk_offsets.b_k_split_offset[tensor_idx] + b_offset_in_batch;
        });

        ChosenEpilogue epilogue{};

        if constexpr(!IsBScaled)
        {
            GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
                shifted_as,
                shifted_bs,
                karg.p_ds_grid,
                karg.p_e_grid + splitk_offsets.c_reduce_offset + c_offset_in_batch,
                lds_buffer,
                karg,
                karg.a_element_op,
                karg.b_element_op,
                karg.cde_element_op,
                epilogue);
        }
        else
        {
            // The B scale pointer also needs a per-batch shift; the A scale pointer is
            // passed through as-is.
            const long_index_t b_scale_offset_in_batch = amd_wave_read_first_lane(
                compute_ptr_offset_of_batch.GetScaleBPtrOffset(batch_idx));

            GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
                shifted_as,
                shifted_bs,
                karg.p_ds_grid,
                karg.p_e_grid + splitk_offsets.c_reduce_offset + c_offset_in_batch,
                karg.p_a_scale_grid,
                karg.p_b_scale_grid + b_scale_offset_in_batch +
                    splitk_offsets.scale_b_k_split_offset,
                lds_buffer,
                karg,
                karg.a_element_op,
                karg.b_element_op,
                karg.cde_element_op,
                epilogue);
        }
#if defined(__gfx11__)
    }
#endif
#else
    ignore = karg;
    ignore = compute_ptr_offset_of_batch;
#endif
}
// Kernel entry point for the WMMA CShuffle-v3 GEMM with a pre-shuffled B operand.
//
// Always uses GridwiseGemm::EpilogueCShuffle. Besides the usual arguments, Run receives
// an A k-id of 0 and a B k-id computed as blockIdx.z * ceil(K / KPack).
//
// When compiled for anything other than gfx11/gfx12 the body only marks `karg` as used.
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
    // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions, so the body is
    // discarded for AtomicAdd on half_t / bhalf_t outputs.
    using OutputElement = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
    constexpr bool is_atomic_add =
        EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd;
    constexpr bool is_half_or_bhalf_output =
        std::is_same_v<OutputElement, ck::half_t> || std::is_same_v<OutputElement, ck::bhalf_t>;
    if constexpr(!is_atomic_add || !is_half_or_bhalf_output)
    {
#endif
        using CShuffleEpilogue = typename GridwiseGemm::EpilogueCShuffle;

        constexpr index_t lds_bytes =
            GridwiseGemm::template GetSharedMemoryNumberOfByte<CShuffleEpilogue>();
        __shared__ char lds_buffer[lds_bytes];

        typename GridwiseGemm::SplitKBatchOffset splitk_offsets(karg, blockIdx.z);

        const index_t k_packs_total = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack);
        const index_t b_k_id        = blockIdx.z * k_packs_total;

        CShuffleEpilogue epilogue{};

        GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
            lds_buffer,
            splitk_offsets,
            karg,
            epilogue,
            0, /* A_k_id == 0 (we shift the pointer for splitk) */
            b_k_id);
#if defined(__gfx11__)
    }
#endif
#else
    ignore = karg;
#endif
}
template <typename ALayout,
typename BLayout,
typename DsLayout,
@@ -903,76 +691,30 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
false,
IsBPreShuffled>())>;
// Used to create obj in global function and pass it to Run method
using EpilogueCShuffle =
EpilogueCShuffle<DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
MPerBlock,
NPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
CDEElementwiseOperation,
ThisThreadBlock,
BlockwiseGemmPipe>;
struct Traits
{
using DsDataType_ = DsDataType;
using EDataType_ = EDataType;
using AccDataType_ = AccDataType;
using CShuffleDataType_ = CShuffleDataType;
using CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ =
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors;
using CDEElementwiseOperation_ = CDEElementwiseOperation;
using ThisThreadBlock_ = ThisThreadBlock;
using BlockwiseGemmPipe_ = BlockwiseGemmPipe;
using EpilogueDirectStore = EpilogueDirectStore<DsDataType,
EDataType,
AccDataType,
MRepeat,
NRepeat,
CDEElementwiseOperation,
BlockwiseGemmPipe>;
using EpilogueWelfordCShuffle = EpilogueWelfordCShuffle<
DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
MPerBlock,
NPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
CDEElementwiseOperation,
ThisThreadBlock,
BlockwiseGemmPipe,
BlockSize>;
template <typename ReduceTrait>
using EpilogueReduceCShuffle = EpilogueReduceCShuffle<
DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
MPerBlock,
NPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
CDEElementwiseOperation,
ThisThreadBlock,
BlockwiseGemmPipe,
GemmSpec,
BlockSize,
ReduceTrait>;
static constexpr auto MPerBlock_ = MPerBlock;
static constexpr auto NPerBlock_ = NPerBlock;
static constexpr auto MPerWmma_ = MPerWmma;
static constexpr auto NPerWmma_ = NPerWmma;
static constexpr auto MRepeat_ = MRepeat;
static constexpr auto NRepeat_ = NRepeat;
static constexpr auto CShuffleMRepeatPerShuffle_ = CShuffleMRepeatPerShuffle;
static constexpr auto CShuffleNRepeatPerShuffle_ = CShuffleNRepeatPerShuffle;
static constexpr auto GemmSpec_ = GemmSpec;
static constexpr auto BlockSize_ = BlockSize;
};
template <typename DEGridDesc>
__host__ __device__ static constexpr auto
@@ -1324,7 +1066,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
}
template <typename EpilogueType>
template <typename Epilogue>
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
@@ -1346,11 +1088,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
max_lds_align)
: 0;
if constexpr(EpilogueType::IsLDSNeeded())
if constexpr(Epilogue::IsLDSNeeded())
{
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
EpilogueType::
Epilogue::
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
constexpr auto c_block_size =

View File

@@ -0,0 +1,223 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/scheduler_enum.hpp"
namespace ck {
// Kernel entry point for the WMMA CShuffle-v3 GEMM.
//
// Selects EpilogueType::DirectStore when the gridwise GEMM reports both
// IsBWaveTransferApplicable and UseDirectStore, and EpilogueType::CShuffle otherwise. The
// concrete epilogue struct comes from get_epilogue_t; the shared-memory buffer is sized
// for it and everything is forwarded to GridwiseGemm::Run. blockIdx.z is handed to
// SplitKBatchOffset.
//
// When compiled for anything other than gfx11/gfx12 the body only marks `karg` as used.
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
    // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions, so the body is
    // discarded for AtomicAdd on half_t / bhalf_t outputs.
    using OutputElement = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
    constexpr bool is_atomic_add =
        EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd;
    constexpr bool is_half_or_bhalf_output =
        std::is_same_v<OutputElement, ck::half_t> || std::is_same_v<OutputElement, ck::bhalf_t>;
    if constexpr(!is_atomic_add || !is_half_or_bhalf_output)
    {
#endif
        // Epilogue choice is made entirely at compile time.
        constexpr bool use_direct_store =
            GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore;
        constexpr EpilogueType epilogue_kind =
            use_direct_store ? EpilogueType::DirectStore : EpilogueType::CShuffle;
        using ChosenEpilogue = get_epilogue_t<epilogue_kind, GridwiseGemm>;

        // Shared-memory size depends on the epilogue that was picked.
        constexpr index_t lds_bytes =
            GridwiseGemm::template GetSharedMemoryNumberOfByte<ChosenEpilogue>();
        __shared__ char lds_buffer[lds_bytes];

        typename GridwiseGemm::SplitKBatchOffset splitk_offsets(karg, blockIdx.z);
        ChosenEpilogue epilogue{};

        GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
            lds_buffer, splitk_offsets, karg, epilogue);
#if defined(__gfx11__)
    }
#endif
#else
    ignore = karg;
#endif
}
// Kernel entry point for the batched WMMA CShuffle-v3 GEMM.
//
// blockIdx.y selects the batch and blockIdx.z is handed to SplitKBatchOffset. The A and B
// pointers are shifted by both the split-K offset and the per-batch offset before being
// passed to GridwiseGemm::Run; the E pointer is shifted by c_reduce_offset plus the
// per-batch C offset. With IsBScaled the Run overload taking scale pointers is used, and
// only the B scale pointer is shifted. The epilogue struct is obtained through
// get_epilogue_t (DirectStore or CShuffle).
//
// When compiled for anything other than gfx11/gfx12 the body only marks its arguments as
// used.
template <typename GridwiseGemm,
          typename ComputePtrOffsetOfStridedBatch,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          bool IsBScaled           = false,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    kernel_batched_gemm_wmma_cshuffle_v3(
        // This works for now, but what is actually received is a
        // DeviceBatchedGemm_Wmma_CShuffleV3::Argument, through implicit conversion to its
        // base class!
        typename GridwiseGemm::Argument karg,
        const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
    // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions, so the body is
    // discarded for AtomicAdd on half_t / bhalf_t outputs.
    using OutputElement = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
    constexpr bool is_atomic_add =
        CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd;
    constexpr bool is_half_or_bhalf_output =
        std::is_same_v<OutputElement, ck::half_t> || std::is_same_v<OutputElement, ck::bhalf_t>;
    if constexpr(!is_atomic_add || !is_half_or_bhalf_output)
    {
#endif
        // Batching would normally stretch the grid Z dimension (the outermost one), but
        // that only works if lower-level code does not use Z for anything else. K batching
        // relies directly on blockIdx.z through SplitKBatchOffset, so for now the grid Y
        // dimension carries the batch index. This may be a bit fragile.
        const index_t batch_idx = amd_wave_read_first_lane(blockIdx.y);

        const long_index_t a_offset_in_batch =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(batch_idx));
        const long_index_t b_offset_in_batch =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(batch_idx));
        const long_index_t c_offset_in_batch =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(batch_idx));

        // Epilogue choice is made entirely at compile time.
        constexpr bool use_direct_store =
            GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore;
        constexpr EpilogueType epilogue_kind =
            use_direct_store ? EpilogueType::DirectStore : EpilogueType::CShuffle;
        using ChosenEpilogue = get_epilogue_t<epilogue_kind, GridwiseGemm>;

        constexpr index_t lds_bytes =
            GridwiseGemm::template GetSharedMemoryNumberOfByte<ChosenEpilogue>();
        __shared__ char lds_buffer[lds_bytes];

        typename GridwiseGemm::SplitKBatchOffset splitk_offsets(karg, blockIdx.z);

        // Shift every A pointer by its split-K offset and by the batch offset.
        typename GridwiseGemm::AsGridPointer shifted_as;
        static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto tensor_idx) {
            using AElement = remove_cvref_t<
                tuple_element_t<tensor_idx.value, typename GridwiseGemm::AsDataType_>>;
            const auto* a_base = static_cast<const AElement*>(karg.p_as_grid[tensor_idx]);
            shifted_as(tensor_idx) =
                a_base + splitk_offsets.a_k_split_offset[tensor_idx] + a_offset_in_batch;
        });

        // Shift every B pointer by its split-K offset and by the batch offset.
        typename GridwiseGemm::BsGridPointer shifted_bs;
        static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto tensor_idx) {
            using BElement = remove_cvref_t<
                tuple_element_t<tensor_idx.value, typename GridwiseGemm::BsDataType_>>;
            const auto* b_base = static_cast<const BElement*>(karg.p_bs_grid[tensor_idx]);
            shifted_bs(tensor_idx) =
                b_base + splitk_offsets.b_k_split_offset[tensor_idx] + b_offset_in_batch;
        });

        ChosenEpilogue epilogue{};

        if constexpr(!IsBScaled)
        {
            GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
                shifted_as,
                shifted_bs,
                karg.p_ds_grid,
                karg.p_e_grid + splitk_offsets.c_reduce_offset + c_offset_in_batch,
                lds_buffer,
                karg,
                karg.a_element_op,
                karg.b_element_op,
                karg.cde_element_op,
                epilogue);
        }
        else
        {
            // The B scale pointer also needs a per-batch shift; the A scale pointer is
            // passed through as-is.
            const long_index_t b_scale_offset_in_batch = amd_wave_read_first_lane(
                compute_ptr_offset_of_batch.GetScaleBPtrOffset(batch_idx));

            GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
                shifted_as,
                shifted_bs,
                karg.p_ds_grid,
                karg.p_e_grid + splitk_offsets.c_reduce_offset + c_offset_in_batch,
                karg.p_a_scale_grid,
                karg.p_b_scale_grid + b_scale_offset_in_batch +
                    splitk_offsets.scale_b_k_split_offset,
                lds_buffer,
                karg,
                karg.a_element_op,
                karg.b_element_op,
                karg.cde_element_op,
                epilogue);
        }
#if defined(__gfx11__)
    }
#endif
#else
    ignore = karg;
    ignore = compute_ptr_offset_of_batch;
#endif
}
// Kernel entry point for the WMMA CShuffle-v3 GEMM with a pre-shuffled B operand.
//
// Always uses the CShuffle epilogue, obtained through get_epilogue_t. Besides the usual
// arguments, Run receives an A k-id of 0 and a B k-id computed as
// blockIdx.z * ceil(K / KPack).
//
// When compiled for anything other than gfx11/gfx12 the body only marks `karg` as used.
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
    // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions, so the body is
    // discarded for AtomicAdd on half_t / bhalf_t outputs.
    using OutputElement = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
    constexpr bool is_atomic_add =
        EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd;
    constexpr bool is_half_or_bhalf_output =
        std::is_same_v<OutputElement, ck::half_t> || std::is_same_v<OutputElement, ck::bhalf_t>;
    if constexpr(!is_atomic_add || !is_half_or_bhalf_output)
    {
#endif
        using CShuffleEpilogue = get_epilogue_t<EpilogueType::CShuffle, GridwiseGemm>;

        constexpr index_t lds_bytes =
            GridwiseGemm::template GetSharedMemoryNumberOfByte<CShuffleEpilogue>();
        __shared__ char lds_buffer[lds_bytes];

        typename GridwiseGemm::SplitKBatchOffset splitk_offsets(karg, blockIdx.z);

        const index_t k_packs_total = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack);
        const index_t b_k_id        = blockIdx.z * k_packs_total;

        CShuffleEpilogue epilogue{};

        GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
            lds_buffer,
            splitk_offsets,
            karg,
            epilogue,
            0, /* A_k_id == 0 (we shift the pointer for splitk) */
            b_k_id);
#if defined(__gfx11__)
    }
#endif
#else
    ignore = karg;
#endif
}
} // namespace ck