mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-10 08:18:26 +00:00
[rocm-libraries] ROCm/rocm-libraries#4871 (commit 7d4c040)
[CK] Decouple EpilogueArgs from GridwiseGemm implementation (#4871) This is duplicate of #4537. I could not re-open it since te target branch got deleted and could not change the target branch since it was closed... :) ## Motivation Right now, all the Epilogues structs are declared inside the base gridwise struct. They should be independent of it and the specialization of the selected Epilogue Type should be declared within the the kernel function. ## Technical Details All Epilogue structs depend on template parameters that are known to the base Gridwise Gemm struct. In this PR, we export them to be used independently by any struct that might need to extract them. This approach will serve the decoupling purposes for the Epilogues, but also enable future constructs to use and expand this approach. See 30e2a4c01b64bdea68857c7badd9d7cffbf1adb9. Right now an issue that arises is that when implementing a new Epilogue Type, the developer is not forced to decide where this struct should/can be used or not. To fix this I propose defining an `enum struct EpilogueType` that will be used to fetch the Epilogue specialization through a helper struct. See a943ac8d130e12d6843715b322181186e54ba15c. Note that all the instantiation details will stay in this helper struct. Also note the static assertion in the else statement. ## Test Plan Test with existing CI, as nothing is added/removed. ## Test Result All relevant existing CI tests should pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
This commit is contained in:
committed by
GitHub
parent
b1975951d4
commit
c7fac341de
@@ -14,6 +14,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -57,12 +58,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { p_ds_grid_batch(i) = karg.p_ds_grid_[i] + ds_batch_offset[i]; });
|
||||
|
||||
using EpilogueType = typename std::conditional<GridwiseOp::IsBWaveTransferApplicable &&
|
||||
GridwiseOp::UseDirectStore,
|
||||
typename GridwiseOp::EpilogueDirectStore,
|
||||
typename GridwiseOp::EpilogueCShuffle>::type;
|
||||
constexpr auto epilogue_type =
|
||||
GridwiseOp::IsBWaveTransferApplicable && GridwiseOp::UseDirectStore
|
||||
? EpilogueType::DirectStore
|
||||
: EpilogueType::CShuffle;
|
||||
using SelectedEpilogue = get_epilogue_t<epilogue_type, GridwiseOp>;
|
||||
|
||||
constexpr index_t LDS_size = GridwiseOp::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseOp::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
@@ -70,7 +73,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
GridwiseOp::MakeBGridDescriptor_BK0_N_BK1(karg.b_grid_desc_n_k_);
|
||||
|
||||
auto epilogue_args = EpilogueType{};
|
||||
auto epilogue_args = SelectedEpilogue{};
|
||||
GridwiseOp::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set, TailNum>(
|
||||
p_as_grid_batch,
|
||||
p_bs_grid_batch,
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -61,8 +62,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
const long_index_t c_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
|
||||
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
using SelectedEpilogue = get_epilogue_t<EpilogueType::CShuffle, GridwiseGemm>;
|
||||
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
@@ -81,7 +84,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i];
|
||||
});
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
auto epilogue_args = SelectedEpilogue{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared, splitk_batch_offset, karg, epilogue_args);
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
@@ -50,9 +52,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle<ReduceTrait>;
|
||||
using SelectedEpilogue =
|
||||
get_epilogue_t<EpilogueType::ReduceCShuffle, GridwiseGemm, ReduceTrait>;
|
||||
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
@@ -85,11 +89,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
|
||||
});
|
||||
|
||||
auto epilogue_args = EpilogueType(reduces_batch,
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops,
|
||||
karg.M,
|
||||
tensor_operation::element_wise::PassThrough{});
|
||||
auto epilogue_args = SelectedEpilogue(reduces_batch,
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops,
|
||||
karg.M,
|
||||
tensor_operation::element_wise::PassThrough{});
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_as_grid_shift,
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common_kernels.hpp"
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -47,14 +49,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle<ReduceTrait>;
|
||||
using SelectedEpilogue =
|
||||
get_epilogue_t<EpilogueType::ReduceCShuffle, GridwiseGemm, ReduceTrait>;
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
auto epilogue_args = EpilogueType(
|
||||
auto epilogue_args = SelectedEpilogue(
|
||||
p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M, d0_element_op);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
@@ -48,14 +49,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueWelfordCShuffle>();
|
||||
using SelectedEpilogue = get_epilogue_t<EpilogueType::WelfordCShuffle, GridwiseGemm>;
|
||||
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueWelfordCShuffle(
|
||||
auto epilogue_args = SelectedEpilogue(
|
||||
p_welford_mean_grid, p_welford_var_grid, p_welford_count_grid, karg.M, karg.N);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
@@ -298,14 +300,16 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
|
||||
return PadTensorDescriptor(grid_desc_x, make_tuple(XPerTile), Sequence<true>{});
|
||||
}
|
||||
|
||||
using SelectedEpilogue = get_epilogue_t<EpilogueType::WelfordCShuffle, GridwiseGemmWelford>;
|
||||
|
||||
using LayernormMeanVarGridDesc_M_NBlock =
|
||||
decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N<
|
||||
decltype(SelectedEpilogue::template MakeMeanVarDescriptor_M_N<
|
||||
Sequence<true, true>,
|
||||
LayernormBlockTileSize_M_N::At(0),
|
||||
LayernormBlockTileSize_M_N::At(1)>(1, 1));
|
||||
|
||||
using LayernormCountGridDesc_M_NBlock =
|
||||
decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N<
|
||||
decltype(SelectedEpilogue::template MakeCountDescriptor_M_N<
|
||||
Sequence<true, true>,
|
||||
LayernormBlockTileSize_M_N::At(0),
|
||||
LayernormBlockTileSize_M_N::At(1)>(1, 1));
|
||||
@@ -398,13 +402,13 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) { p_ds_grid_[i] = p_ds_grid[i]; });
|
||||
|
||||
layernorm_mean_var_grid_desc_m_nblock_ =
|
||||
GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N<
|
||||
SelectedEpilogue::template MakeMeanVarDescriptor_M_N<
|
||||
Sequence<true, true>,
|
||||
LayernormBlockTileSize_M_N::At(0),
|
||||
LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_);
|
||||
|
||||
layernorm_count_grid_desc_m_nblock_ =
|
||||
GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N<
|
||||
SelectedEpilogue::template MakeCountDescriptor_M_N<
|
||||
Sequence<true, true>,
|
||||
LayernormBlockTileSize_M_N::At(0),
|
||||
LayernormBlockTileSize_M_N::At(1)>(MRaw, gemm_nblock_);
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -46,18 +48,19 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle<ReduceTrait>;
|
||||
using SelectedEpilogue =
|
||||
get_epilogue_t<EpilogueType::ReduceCShuffle, GridwiseGemm, ReduceTrait>;
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
auto epilogue_args = EpilogueType(p_reduces_grid,
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops,
|
||||
karg.M,
|
||||
tensor_operation::element_wise::PassThrough{});
|
||||
auto epilogue_args = SelectedEpilogue(p_reduces_grid,
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops,
|
||||
karg.M,
|
||||
tensor_operation::element_wise::PassThrough{});
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared, splitk_batch_offset, karg, epilogue_args);
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common_kernels.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
@@ -72,9 +73,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
if constexpr(EGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd)
|
||||
{
|
||||
#endif
|
||||
__shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>()];
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
using SelectedEpilogue = get_epilogue_t<EpilogueType::CShuffle, GridwiseGemm>;
|
||||
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
auto epilogue_args = SelectedEpilogue{};
|
||||
|
||||
const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
index_t left = 0;
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
@@ -68,13 +69,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd)
|
||||
{
|
||||
#endif
|
||||
using SelectedEpilogue = get_epilogue_t<EpilogueType::CShuffle, GridwiseGemm>;
|
||||
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
auto epilogue_args = SelectedEpilogue{};
|
||||
|
||||
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_WEIGHT,
|
||||
AGridDesc_AK0_M_K1,
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
@@ -69,12 +70,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd)
|
||||
{
|
||||
#endif
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
using SelectedEpilogue = get_epilogue_t<EpilogueType::CShuffle, GridwiseGemm>;
|
||||
|
||||
const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
auto epilogue_args = SelectedEpilogue{};
|
||||
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_WEIGHT,
|
||||
AGridDesc_AK0_M_K1,
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
@@ -68,12 +69,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd)
|
||||
{
|
||||
#endif
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
using SelectedEpilogue = get_epilogue_t<EpilogueType::CShuffle, GridwiseGemm>;
|
||||
|
||||
const auto block_2_ctile_map_ = typename GridwiseGemm::Block2CTileMap{karg.M, karg.N, 4};
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
auto epilogue_args = SelectedEpilogue{};
|
||||
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
GridwiseGemm::template Run<GridwiseGemm::ConvRegime::BWD_WEIGHT,
|
||||
AGridDesc_AK0_M_K1,
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
@@ -90,17 +91,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
using EpilogueType =
|
||||
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
|
||||
GridwiseGemm::UseDirectStore,
|
||||
typename GridwiseGemm::EpilogueDirectStore,
|
||||
typename GridwiseGemm::EpilogueCShuffle>::type;
|
||||
constexpr auto epilogue_type =
|
||||
GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore
|
||||
? EpilogueType::DirectStore
|
||||
: EpilogueType::CShuffle;
|
||||
using SelectedEpilogue = get_epilogue_t<epilogue_type, GridwiseGemm>;
|
||||
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto epilogue_args = EpilogueType{};
|
||||
auto epilogue_args = SelectedEpilogue{};
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
@@ -50,8 +51,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
const ComputePtrOffset compute_ptr_offset_of_n)
|
||||
{
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
using Epilogue = typename GridwiseGemm::EpilogueCShuffle;
|
||||
__shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte<Epilogue>()];
|
||||
using SelectedEpilogue = get_epilogue_t<EpilogueType::CShuffle, GridwiseGemm>;
|
||||
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
@@ -147,7 +151,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
|
||||
const index_t num_k_block_per_scale = GridwiseGemm::GetKBlockPerScale();
|
||||
|
||||
auto epilogue_args = Epilogue{};
|
||||
auto epilogue_args = SelectedEpilogue{};
|
||||
|
||||
GridwiseGemm::Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
@@ -155,7 +159,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
decltype(e_grid_desc),
|
||||
decltype(a_scale_struct),
|
||||
decltype(b_scale_struct),
|
||||
Epilogue,
|
||||
SelectedEpilogue,
|
||||
HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid_,
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
@@ -57,12 +58,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
const CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
|
||||
GridwiseGemm::UseDirectStore,
|
||||
typename GridwiseGemm::EpilogueDirectStore,
|
||||
typename GridwiseGemm::EpilogueCShuffle>::type;
|
||||
constexpr auto epilogue_type =
|
||||
GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore
|
||||
? EpilogueType::DirectStore
|
||||
: EpilogueType::CShuffle;
|
||||
using SelectedEpilogue = get_epilogue_t<epilogue_type, GridwiseGemm>;
|
||||
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
const index_t KBatch = 1;
|
||||
@@ -139,13 +142,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
const auto block_2_etile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);
|
||||
|
||||
auto epilogue_args = EpilogueType{};
|
||||
auto epilogue_args = SelectedEpilogue{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
decltype(block_2_etile_map),
|
||||
EpilogueType,
|
||||
SelectedEpilogue,
|
||||
1,
|
||||
2>(
|
||||
p_as_grid_,
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
|
||||
@@ -66,12 +67,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
const CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
|
||||
GridwiseGemm::UseDirectStore,
|
||||
typename GridwiseGemm::EpilogueDirectStore,
|
||||
typename GridwiseGemm::EpilogueCShuffle>::type;
|
||||
constexpr auto epilogue_type =
|
||||
GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore
|
||||
? EpilogueType::DirectStore
|
||||
: EpilogueType::CShuffle;
|
||||
using SelectedEpilogue = get_epilogue_t<epilogue_type, GridwiseGemm>;
|
||||
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ uint8_t p_shared[LDS_size];
|
||||
|
||||
const auto gemm_desc_ptr =
|
||||
@@ -154,7 +157,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
gemm_desc_ptr[group_id].StrideE,
|
||||
1);
|
||||
|
||||
auto epilogue_args = EpilogueType{};
|
||||
auto epilogue_args = SelectedEpilogue{};
|
||||
constexpr TailNumber TailNum = TailNumber::Full;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -94,12 +95,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
|
||||
GridwiseGemm::UseDirectStore,
|
||||
typename GridwiseGemm::EpilogueDirectStore,
|
||||
typename GridwiseGemm::EpilogueCShuffle>::type;
|
||||
constexpr auto epilogue_type =
|
||||
GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore
|
||||
? EpilogueType::DirectStore
|
||||
: EpilogueType::CShuffle;
|
||||
using SelectedEpilogue = get_epilogue_t<epilogue_type, GridwiseGemm>;
|
||||
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
@@ -179,13 +182,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
const auto splitk_batch_offset =
|
||||
typename GridwiseGemm::SplitKBatchOffset(kernel_arg, tile_index[Number<0>{}]);
|
||||
|
||||
auto epilogue_args = EpilogueType{};
|
||||
auto epilogue_args = SelectedEpilogue{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
GroupedGemmBlock2ETileMap,
|
||||
EpilogueType,
|
||||
SelectedEpilogue,
|
||||
1, // Block2CTileMap MBlock index
|
||||
2 // Block2CTileMap NBlock index
|
||||
>(static_cast<void*>(p_shared),
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
@@ -41,12 +42,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
const index_t group_count)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
|
||||
GridwiseGemm::UseDirectStore,
|
||||
typename GridwiseGemm::EpilogueDirectStore,
|
||||
typename GridwiseGemm::EpilogueCShuffle>::type;
|
||||
constexpr auto epilogue_type =
|
||||
GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore
|
||||
? EpilogueType::DirectStore
|
||||
: EpilogueType::CShuffle;
|
||||
using SelectedEpilogue = get_epilogue_t<epilogue_type, GridwiseGemm>;
|
||||
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
@@ -93,13 +96,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
|
||||
auto splitk_batch_offset =
|
||||
typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]);
|
||||
auto epilogue_args = EpilogueType{};
|
||||
auto epilogue_args = SelectedEpilogue{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
Block2CTileMap,
|
||||
EpilogueType,
|
||||
SelectedEpilogue,
|
||||
1, // Block2CTileMap MBlock index
|
||||
2 // Block2CTileMap NBlock index
|
||||
>(static_cast<void*>(p_shared),
|
||||
|
||||
125
include/ck/tensor_operation/gpu/grid/epilogue_type.hpp
Normal file
125
include/ck/tensor_operation/gpu/grid/epilogue_type.hpp
Normal file
@@ -0,0 +1,125 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum class EpilogueType
|
||||
{
|
||||
CShuffle = 0,
|
||||
DirectStore,
|
||||
WelfordCShuffle,
|
||||
ReduceCShuffle
|
||||
};
|
||||
|
||||
template <EpilogueType type, typename GridwiseGemm, typename ReduceTrait = Tuple<>>
|
||||
struct get_epilogue
|
||||
{
|
||||
private:
|
||||
static constexpr auto get_epilogue_implementation()
|
||||
{
|
||||
static_assert((type == EpilogueType::ReduceCShuffle) ==
|
||||
(!std::is_same_v<ReduceTrait, Tuple<>>),
|
||||
"Provide a ReduceTrait only if the desired epilogue type is ReduceCShuffle.");
|
||||
using TypeExtractor = typename GridwiseGemm::Traits;
|
||||
|
||||
if constexpr(type == EpilogueType::CShuffle)
|
||||
{
|
||||
return EpilogueCShuffle<
|
||||
typename TypeExtractor::DsDataType_,
|
||||
typename TypeExtractor::EDataType_,
|
||||
typename TypeExtractor::AccDataType_,
|
||||
typename TypeExtractor::CShuffleDataType_,
|
||||
TypeExtractor::MPerBlock_,
|
||||
TypeExtractor::NPerBlock_,
|
||||
TypeExtractor::MPerWmma_,
|
||||
TypeExtractor::NPerWmma_,
|
||||
TypeExtractor::MRepeat_,
|
||||
TypeExtractor::NRepeat_,
|
||||
TypeExtractor::CShuffleMRepeatPerShuffle_,
|
||||
TypeExtractor::CShuffleNRepeatPerShuffle_,
|
||||
typename TypeExtractor::
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
|
||||
typename TypeExtractor::CDEShuffleBlockTransferScalarPerVectors_,
|
||||
typename TypeExtractor::CDEElementwiseOperation_,
|
||||
typename TypeExtractor::ThisThreadBlock_,
|
||||
typename TypeExtractor::BlockwiseGemmPipe_>{};
|
||||
}
|
||||
else if constexpr(type == EpilogueType::DirectStore)
|
||||
{
|
||||
return EpilogueDirectStore<typename TypeExtractor::DsDataType_,
|
||||
typename TypeExtractor::EDataType_,
|
||||
typename TypeExtractor::AccDataType_,
|
||||
TypeExtractor::MRepeat_,
|
||||
TypeExtractor::NRepeat_,
|
||||
typename TypeExtractor::CDEElementwiseOperation_,
|
||||
typename TypeExtractor::BlockwiseGemmPipe_>{};
|
||||
}
|
||||
else if constexpr(type == EpilogueType::WelfordCShuffle)
|
||||
{
|
||||
return EpilogueWelfordCShuffle<
|
||||
typename TypeExtractor::DsDataType_,
|
||||
typename TypeExtractor::EDataType_,
|
||||
typename TypeExtractor::AccDataType_,
|
||||
typename TypeExtractor::CShuffleDataType_,
|
||||
TypeExtractor::MPerBlock_,
|
||||
TypeExtractor::NPerBlock_,
|
||||
TypeExtractor::MPerWmma_,
|
||||
TypeExtractor::NPerWmma_,
|
||||
TypeExtractor::MRepeat_,
|
||||
TypeExtractor::NRepeat_,
|
||||
TypeExtractor::CShuffleMRepeatPerShuffle_,
|
||||
TypeExtractor::CShuffleNRepeatPerShuffle_,
|
||||
typename TypeExtractor::
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
|
||||
typename TypeExtractor::CDEShuffleBlockTransferScalarPerVectors_,
|
||||
typename TypeExtractor::CDEElementwiseOperation_,
|
||||
typename TypeExtractor::ThisThreadBlock_,
|
||||
typename TypeExtractor::BlockwiseGemmPipe_,
|
||||
TypeExtractor::BlockSize_>{{}, {}, {}, {}, {}};
|
||||
}
|
||||
else if constexpr(type == EpilogueType::ReduceCShuffle)
|
||||
{
|
||||
return EpilogueReduceCShuffle<
|
||||
typename TypeExtractor::DsDataType_,
|
||||
typename TypeExtractor::EDataType_,
|
||||
typename TypeExtractor::AccDataType_,
|
||||
typename TypeExtractor::CShuffleDataType_,
|
||||
TypeExtractor::MPerBlock_,
|
||||
TypeExtractor::NPerBlock_,
|
||||
TypeExtractor::MPerWmma_,
|
||||
TypeExtractor::NPerWmma_,
|
||||
TypeExtractor::MRepeat_,
|
||||
TypeExtractor::NRepeat_,
|
||||
TypeExtractor::CShuffleMRepeatPerShuffle_,
|
||||
TypeExtractor::CShuffleNRepeatPerShuffle_,
|
||||
typename TypeExtractor::
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_,
|
||||
typename TypeExtractor::CDEShuffleBlockTransferScalarPerVectors_,
|
||||
typename TypeExtractor::CDEElementwiseOperation_,
|
||||
typename TypeExtractor::ThisThreadBlock_,
|
||||
typename TypeExtractor::BlockwiseGemmPipe_,
|
||||
TypeExtractor::GemmSpec_,
|
||||
TypeExtractor::BlockSize_,
|
||||
ReduceTrait>{{}, {}, {}, {}, {}};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Not implemented for the specified type.");
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
using Type = decltype(get_epilogue_implementation());
|
||||
};
|
||||
|
||||
template <EpilogueType type, typename GridwiseGemm, typename ReduceTrait = Tuple<>>
|
||||
using get_epilogue_t = typename get_epilogue<type, GridwiseGemm, ReduceTrait>::Type;
|
||||
|
||||
} // namespace ck
|
||||
@@ -19,10 +19,6 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp"
|
||||
@@ -33,214 +29,6 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<e_data_type, ck::half_t> ||
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
using EpilogueType =
|
||||
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
|
||||
GridwiseGemm::UseDirectStore,
|
||||
typename GridwiseGemm::EpilogueDirectStore,
|
||||
typename GridwiseGemm::EpilogueCShuffle>::type;
|
||||
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
auto epilogue_args = EpilogueType{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared, splitk_batch_offset, karg, epilogue_args);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename ComputePtrOffsetOfStridedBatch,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
bool IsBScaled = false,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_batched_gemm_wmma_cshuffle_v3(
|
||||
typename GridwiseGemm::Argument karg, // This works for now but it actually receives a
|
||||
// DeviceBatchedGemm_Wmma_CShuffleV3::Argument
|
||||
// argument through implicit conversion to base class!
|
||||
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<c_data_type, ck::half_t> ||
|
||||
std::is_same_v<c_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
// The normal approach to batching would be to increase the grid size by just stretching out
|
||||
// the grid Z dimension (which is the outermost dimension), but this depends on lower level
|
||||
// functions not directly using the Z dimension for other calculations. As it turns out, k
|
||||
// batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now
|
||||
// we will use the grid Y dimension for batching. This may be a bit fragile.
|
||||
const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t c_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
|
||||
|
||||
using EpilogueType =
|
||||
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
|
||||
GridwiseGemm::UseDirectStore,
|
||||
typename GridwiseGemm::EpilogueDirectStore,
|
||||
typename GridwiseGemm::EpilogueCShuffle>::type;
|
||||
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
||||
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
// shift A matrices pointer for splitk
|
||||
typename GridwiseGemm::AsGridPointer p_as_grid_shift;
|
||||
static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ =
|
||||
remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::AsDataType_>>;
|
||||
p_as_grid_shift(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
|
||||
splitk_batch_offset.a_k_split_offset[i] + a_batch_offset;
|
||||
});
|
||||
|
||||
// shift B matrices pointer for splitk
|
||||
typename GridwiseGemm::BsGridPointer p_bs_grid_shift;
|
||||
static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ =
|
||||
remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::BsDataType_>>;
|
||||
p_bs_grid_shift(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
|
||||
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
|
||||
});
|
||||
|
||||
auto epilogue_args = EpilogueType{};
|
||||
|
||||
if constexpr(IsBScaled)
|
||||
{
|
||||
const long_index_t b_scale_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx));
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
p_as_grid_shift,
|
||||
p_bs_grid_shift,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
|
||||
karg.p_a_scale_grid,
|
||||
karg.p_b_scale_grid + b_scale_batch_offset +
|
||||
splitk_batch_offset.scale_b_k_split_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
p_as_grid_shift,
|
||||
p_bs_grid_shift,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<e_data_type, ck::half_t> ||
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack);
|
||||
const index_t k_id = blockIdx.z * num_k_per_block;
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared,
|
||||
splitk_batch_offset,
|
||||
karg,
|
||||
epilogue_args,
|
||||
0, /* A_k_id == 0 (we shift the pointer for splitk) */
|
||||
k_id);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
@@ -903,76 +691,30 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
false,
|
||||
IsBPreShuffled>())>;
|
||||
|
||||
// Used to create obj in global function and pass it to Run method
|
||||
using EpilogueCShuffle =
|
||||
EpilogueCShuffle<DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CDEElementwiseOperation,
|
||||
ThisThreadBlock,
|
||||
BlockwiseGemmPipe>;
|
||||
struct Traits
|
||||
{
|
||||
using DsDataType_ = DsDataType;
|
||||
using EDataType_ = EDataType;
|
||||
using AccDataType_ = AccDataType;
|
||||
using CShuffleDataType_ = CShuffleDataType;
|
||||
using CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ =
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
|
||||
using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors;
|
||||
using CDEElementwiseOperation_ = CDEElementwiseOperation;
|
||||
using ThisThreadBlock_ = ThisThreadBlock;
|
||||
using BlockwiseGemmPipe_ = BlockwiseGemmPipe;
|
||||
|
||||
using EpilogueDirectStore = EpilogueDirectStore<DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
CDEElementwiseOperation,
|
||||
BlockwiseGemmPipe>;
|
||||
|
||||
using EpilogueWelfordCShuffle = EpilogueWelfordCShuffle<
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CDEElementwiseOperation,
|
||||
ThisThreadBlock,
|
||||
BlockwiseGemmPipe,
|
||||
BlockSize>;
|
||||
|
||||
template <typename ReduceTrait>
|
||||
using EpilogueReduceCShuffle = EpilogueReduceCShuffle<
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
CDEElementwiseOperation,
|
||||
ThisThreadBlock,
|
||||
BlockwiseGemmPipe,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
ReduceTrait>;
|
||||
static constexpr auto MPerBlock_ = MPerBlock;
|
||||
static constexpr auto NPerBlock_ = NPerBlock;
|
||||
static constexpr auto MPerWmma_ = MPerWmma;
|
||||
static constexpr auto NPerWmma_ = NPerWmma;
|
||||
static constexpr auto MRepeat_ = MRepeat;
|
||||
static constexpr auto NRepeat_ = NRepeat;
|
||||
static constexpr auto CShuffleMRepeatPerShuffle_ = CShuffleMRepeatPerShuffle;
|
||||
static constexpr auto CShuffleNRepeatPerShuffle_ = CShuffleNRepeatPerShuffle;
|
||||
static constexpr auto GemmSpec_ = GemmSpec;
|
||||
static constexpr auto BlockSize_ = BlockSize;
|
||||
};
|
||||
|
||||
template <typename DEGridDesc>
|
||||
__host__ __device__ static constexpr auto
|
||||
@@ -1324,7 +1066,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
|
||||
}
|
||||
|
||||
template <typename EpilogueType>
|
||||
template <typename Epilogue>
|
||||
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
@@ -1346,11 +1088,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
max_lds_align)
|
||||
: 0;
|
||||
|
||||
if constexpr(EpilogueType::IsLDSNeeded())
|
||||
if constexpr(Epilogue::IsLDSNeeded())
|
||||
{
|
||||
// LDS allocation for C shuffle in LDS
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
EpilogueType::
|
||||
Epilogue::
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
|
||||
|
||||
constexpr auto c_block_size =
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/epilogue_type.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<e_data_type, ck::half_t> ||
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
constexpr auto epilogue_type =
|
||||
GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore
|
||||
? EpilogueType::DirectStore
|
||||
: EpilogueType::CShuffle;
|
||||
using SelectedEpilogue = get_epilogue_t<epilogue_type, GridwiseGemm>;
|
||||
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
auto epilogue_args = SelectedEpilogue{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared, splitk_batch_offset, karg, epilogue_args);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename ComputePtrOffsetOfStridedBatch,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
bool IsBScaled = false,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_batched_gemm_wmma_cshuffle_v3(
|
||||
typename GridwiseGemm::Argument karg, // This works for now but it actually receives a
|
||||
// DeviceBatchedGemm_Wmma_CShuffleV3::Argument
|
||||
// argument through implicit conversion to base class!
|
||||
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<c_data_type, ck::half_t> ||
|
||||
std::is_same_v<c_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
// The normal approach to batching would be to increase the grid size by just stretching out
|
||||
// the grid Z dimension (which is the outermost dimension), but this depends on lower level
|
||||
// functions not directly using the Z dimension for other calculations. As it turns out, k
|
||||
// batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now
|
||||
// we will use the grid Y dimension for batching. This may be a bit fragile.
|
||||
const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t c_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
|
||||
|
||||
constexpr auto epilogue_type =
|
||||
GridwiseGemm::IsBWaveTransferApplicable && GridwiseGemm::UseDirectStore
|
||||
? EpilogueType::DirectStore
|
||||
: EpilogueType::CShuffle;
|
||||
using SelectedEpilogue = get_epilogue_t<epilogue_type, GridwiseGemm>;
|
||||
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
// shift A matrices pointer for splitk
|
||||
typename GridwiseGemm::AsGridPointer p_as_grid_shift;
|
||||
static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ =
|
||||
remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::AsDataType_>>;
|
||||
p_as_grid_shift(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
|
||||
splitk_batch_offset.a_k_split_offset[i] + a_batch_offset;
|
||||
});
|
||||
|
||||
// shift B matrices pointer for splitk
|
||||
typename GridwiseGemm::BsGridPointer p_bs_grid_shift;
|
||||
static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ =
|
||||
remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::BsDataType_>>;
|
||||
p_bs_grid_shift(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
|
||||
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
|
||||
});
|
||||
|
||||
auto epilogue_args = SelectedEpilogue{};
|
||||
|
||||
if constexpr(IsBScaled)
|
||||
{
|
||||
const long_index_t b_scale_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx));
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
p_as_grid_shift,
|
||||
p_bs_grid_shift,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
|
||||
karg.p_a_scale_grid,
|
||||
karg.p_b_scale_grid + b_scale_batch_offset +
|
||||
splitk_batch_offset.scale_b_k_split_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
p_as_grid_shift,
|
||||
p_bs_grid_shift,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
ignore = karg;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<e_data_type, ck::half_t> ||
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
|
||||
using SelectedEpilogue = get_epilogue_t<EpilogueType::CShuffle, GridwiseGemm>;
|
||||
|
||||
constexpr index_t LDS_size =
|
||||
GridwiseGemm::template GetSharedMemoryNumberOfByte<SelectedEpilogue>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack);
|
||||
const index_t k_id = blockIdx.z * num_k_per_block;
|
||||
|
||||
auto epilogue_args = SelectedEpilogue{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared,
|
||||
splitk_batch_offset,
|
||||
karg,
|
||||
epilogue_args,
|
||||
0, /* A_k_id == 0 (we shift the pointer for splitk) */
|
||||
k_id);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user