Wmma support for gemm_reduce (#3145)

* Initial implementation GEMM+Reduce:

 - device struct
 - epilogue struct

* Fix tests, improve profiler and add initial instances

* Add instances

* Fix compilation error

* Address review comments

* Fix logging

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Enrico Degregori
2025-11-12 20:23:54 +01:00
committed by GitHub
parent 299c9bca1b
commit 7414a0f4d4
12 changed files with 1568 additions and 12 deletions

View File

@@ -0,0 +1,661 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
// Kernel entry point for GEMM fused with reductions on the WMMA CShuffle-v3 pipeline.
//
// Builds the reduce epilogue from the reduction pointers / element-wise operators and
// hands it to GridwiseGemm::Run together with a statically sized LDS buffer.
// blockIdx.z is forwarded to GridwiseGemm::SplitKBatchOffset.
// The body is only compiled for gfx11 / gfx12 targets; for any other target the
// arguments are discarded and the kernel does nothing.
template <typename GridwiseGemm,
          typename ReduceTrait,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum EGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    kernel_gemm_reduce_wmma_cshuffle_v3(
        typename GridwiseGemm::Argument karg,
        typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid,
        const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops,
        const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops)
{
#if defined(__gfx11__) || defined(__gfx12__)
#if defined(__gfx11__)
    // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions, so the
    // AtomicAdd + half/bhalf output combination compiles to an empty kernel there.
    using EElement = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
    constexpr bool e_is_half_or_bhalf =
        std::is_same_v<EElement, ck::half_t> || std::is_same_v<EElement, ck::bhalf_t>;
    constexpr bool can_run =
        !(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
          e_is_half_or_bhalf);
#else
    constexpr bool can_run = true;
#endif
    if constexpr(can_run)
    {
        using Epilogue = typename GridwiseGemm::template EpilogueReduceCShuffle<ReduceTrait>;

        // LDS footprint is asked from the gridwise GEMM for this specific epilogue
        constexpr index_t lds_bytes =
            GridwiseGemm::template GetSharedMemoryNumberOfByte<Epilogue>();
        __shared__ char p_shared[lds_bytes];

        auto k_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        auto epilogue =
            Epilogue(p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M);

        GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
            p_shared, k_batch_offset, karg, epilogue);
    }
#else
    // Unsupported target: keep the signature but silence unused-parameter warnings.
    ignore = karg;
    ignore = p_reduces_grid;
    ignore = reduce_in_element_ops;
    ignore = reduce_out_element_ops;
#endif
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
typename ReduceAccDataType, // Reduce
typename ReducePtrsGlobal, // Reduce
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ReduceOperations, // Reduce
typename ReduceInElementwiseOperations, // Reduce
typename ReduceAccElementwiseOperations, // Reduce
typename ReduceGlobalMemoryDataOperation, // Reduce
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock, // Reduce
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, // Reduce
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, // Reduce
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOperations::Size()>
{
using CDEShuffleBlockTransferScalarPerVectors =
Sequence<CDEShuffleBlockTransferScalarPerVector,
CDEShuffleBlockTransferScalarPerVector,
CDEShuffleBlockTransferScalarPerVector>;
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
Tuple<>,
ELayout,
Tuple<ADataType>,
Tuple<BDataType>,
AccDataType,
CShuffleDataType,
Tuple<>,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>;
using ReduceTrait = ReduceTrait_<ReduceAccDataType,
ReducePtrsGlobal,
ReduceOperations,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
ReduceGlobalMemoryDataOperation,
CReduceThreadClusterLengths_MPerBlock_NPerBlock,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>;
// Argument
// Host-side description of one GEMM+reduce problem: tensor pointers, raw problem
// sizes, strides and the element-wise operators. The constructor only copies its
// parameters into the members below; no validation is done here.
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
EDataType* p_c_grid,
ReducePtrsGlobal p_reduces_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ReduceInElementwiseOperations reduce_in_element_ops,
ReduceAccElementwiseOperations reduce_out_element_ops)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_reduces_grid_{p_reduces_grid},
MRaw_{MRaw},
NRaw_{NRaw},
KRaw_{KRaw},
StrideA_{StrideA},
StrideB_{StrideB},
StrideC_{StrideC},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
reduce_in_element_ops_{reduce_in_element_ops},
reduce_out_element_ops_{reduce_out_element_ops}
{
}
// GEMM input pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
// GEMM output pointer (forwarded to the gridwise GEMM as its E pointer)
EDataType* p_c_grid_;
// One output pointer per reduction
ReducePtrsGlobal p_reduces_grid_;
// Problem sizes as given by the caller (forwarded to the gridwise GEMM as M, N, K)
index_t MRaw_;
index_t NRaw_;
index_t KRaw_;
// Strides; StrideC_ is forwarded to the gridwise GEMM as StrideE
index_t StrideA_;
index_t StrideB_;
index_t StrideC_;
// Element-wise operators of the GEMM itself
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
// Per-reduction operators, forwarded unchanged to the kernel
ReduceInElementwiseOperations reduce_in_element_ops_;
ReduceAccElementwiseOperations reduce_out_element_ops_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
typename GridwiseGemm::Argument gemm_arg{
std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},
std::array<const void*, 0>{},
static_cast<EDataType*>(arg.p_c_grid_),
arg.MRaw_,
arg.NRaw_,
arg.KRaw_,
std::array<index_t, 1>{arg.StrideA_}, // StrideAs
std::array<index_t, 1>{arg.StrideB_}, // StrideBs
std::array<index_t, 0>{}, // StrideDs
arg.StrideC_, // StrideE
1, // kbatch
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_};
if(stream_config.log_level_ > 0)
{
gemm_arg.Print();
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
}
if(!GridwiseGemm::CheckValidity(gemm_arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.MRaw_, arg.NRaw_, 1);
float ave_time = 0;
index_t K_split = (arg.KRaw_ + KPerBlock - 1) / KPerBlock * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.KRaw_);
const auto Run = [&](const auto& kernel) {
// Note: cache flushing not supported
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
gemm_arg,
arg.p_reduces_grid_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_);
};
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(TailNum == TailNumber::Full)
{
const auto kernel =
kernel_gemm_reduce_wmma_cshuffle_v3<GridwiseGemm,
ReduceTrait,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
else
{
throw std::runtime_error("wrong! Invalid pipeline setting");
}
}
}
else
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(TailNum == TailNumber::Full)
{
const auto kernel =
kernel_gemm_reduce_wmma_cshuffle_v3<GridwiseGemm,
ReduceTrait,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
else
{
throw std::runtime_error("wrong! Invalid pipeline v1 setting");
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(TailNum == TailNumber::Even)
{
const auto kernel =
kernel_gemm_reduce_wmma_cshuffle_v3<GridwiseGemm,
ReduceTrait,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
else if(TailNum == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_reduce_wmma_cshuffle_v3<GridwiseGemm,
ReduceTrait,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
throw std::runtime_error("wrong! Invalid pipeline v3 setting");
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
// Compile-time parameter validation hook. Currently a stub that accepts every
// template-parameter combination.
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Device implementation supports only gfx11 and gfx12! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
{
if(ck::is_gfx11_supported())
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "FP8 and BF8 not supported on gfx11! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
if((arg.KRaw_ % AK1 != 0 || arg.KRaw_ % BK1 != 0) &&
!(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Without padding, K must be divisible by AK1 and BK1! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},
std::array<const void*, 0>{},
static_cast<EDataType*>(arg.p_c_grid_),
arg.MRaw_,
arg.NRaw_,
arg.KRaw_,
std::array<index_t, 1>{arg.StrideA_}, // StrideAs
std::array<index_t, 1>{arg.StrideB_}, // StrideBs
std::array<index_t, 0>{}, // StrideDs
arg.StrideC_, // StrideE
1, // kbatch
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_};
return GridwiseGemm::CheckValidity(gemm_arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static constexpr int NumReduce = ReduceOperations::Size();
static auto MakeArgument(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 0> p_ds,
void* p_c,
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 0> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 0> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op)
{
(void)p_bias;
(void)p_ds;
(void)StrideDs;
(void)d_element_ops;
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<EDataType*>(p_c),
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
reduce_in_element_ops,
reduce_out_element_ops};
}
// Returns a default-constructed Invoker by value.
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 0> p_ds,
void* p_c,
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 0> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 0> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op,
ck::index_t = 1) override
{
(void)p_bias;
(void)p_ds;
(void)StrideDs;
(void)d_element_ops;
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<EDataType*>(p_c),
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
reduce_in_element_ops,
reduce_out_element_ops);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmReduce_Wmma_CShuffleV3"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerWmma << ", "
<< NPerWmma << ", "
<< MRepeat << ", "
<< NRepeat << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMRepeatPerShuffle << ", "
<< CShuffleNRepeatPerShuffle
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,470 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
namespace ck {
// Compile-time bundle of all reduction-related template parameters.
// Each parameter is re-exported unchanged under the same name with a trailing
// underscore, so the whole set can be passed around as a single type.
template <typename ReduceAccDataType,
typename ReducePtrsGlobal,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename ReduceGlobalMemoryDataOperation,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
struct ReduceTrait_
{
// Type aliases
using ReduceAccDataType_ = ReduceAccDataType;
using ReducePtrsGlobal_ = ReducePtrsGlobal;
using ReduceOperations_ = ReduceOperations;
using ReduceInElementwiseOperations_ = ReduceInElementwiseOperations;
using ReduceAccElementwiseOperations_ = ReduceAccElementwiseOperations;
using ReduceGlobalMemoryDataOperation_ = ReduceGlobalMemoryDataOperation;
using CReduceThreadClusterLengths_MPerBlock_NPerBlock_ =
CReduceThreadClusterLengths_MPerBlock_NPerBlock;
// Integral parameters
static constexpr index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_ =
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock;
static constexpr index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_ =
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock;
};
template <typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
index_t MPerBlock,
index_t NPerBlock,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
typename CDEElementwiseOperation,
typename ThisThreadBlock,
typename BlockwiseGemmPipe,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
typename ReduceTrait>
struct EpilogueReduceCShuffle
: EpilogueCShuffleBase<DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
MPerBlock,
NPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
CDEElementwiseOperation,
ThisThreadBlock,
BlockwiseGemmPipe>
{
using Base = EpilogueCShuffleBase<
DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
MPerBlock,
NPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
CDEElementwiseOperation,
ThisThreadBlock,
BlockwiseGemmPipe>;
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
using Base::GetCShuffleLDSDescriptor;
using Base::GetVgprToLDSEpilogueDescriptor;
using Base::I0;
using Base::I1;
using Base::I3;
using Base::NumDTensor;
// assume Reduce is packed tensor
// Builds the 1-D descriptor of a reduction output of length MRaw.
// For the GEMM specializations that pad M, the length is right-padded up to the
// next multiple of MPerBlock; otherwise the packed raw descriptor is returned.
__device__ static auto MakeReduceGridDescriptor_M(index_t MRaw)
{
    using GemmSpecialization = tensor_operation::device::GemmSpecialization;

    constexpr bool pads_m = GemmSpec == GemmSpecialization::MPadding ||
                            GemmSpec == GemmSpecialization::MNPadding ||
                            GemmSpec == GemmSpecialization::MKPadding ||
                            GemmSpec == GemmSpecialization::MNKPadding;

    const auto raw_desc = make_naive_tensor_descriptor_packed(make_tuple(MRaw));

    if constexpr(pads_m)
    {
        // pad M
        const auto m_rounded_up = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto m_pad        = m_rounded_up - MRaw;

        return transform_tensor_descriptor(raw_desc,
                                           make_tuple(make_right_pad_transform(MRaw, m_pad)),
                                           make_tuple(Sequence<0>{}),
                                           make_tuple(Sequence<0>{}));
    }
    else
    {
        // not pad M
        return raw_desc;
    }
}
using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
// Splits the M dimension of the reduction descriptor into (MBlock, MPerBlock),
// where MBlock is the descriptor length divided by MPerBlock.
__device__ static constexpr auto
MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
{
    const auto num_m_blocks = d_grid_desc_m.GetLength(I0) / MPerBlock;

    return transform_tensor_descriptor(
        d_grid_desc_m,
        make_tuple(make_unmerge_transform(make_tuple(num_m_blocks, Number<MPerBlock>{}))),
        make_tuple(Sequence<0>{}),
        make_tuple(Sequence<0, 1>{}));
}
// Stores the reduction output pointers and element-wise operators and builds the
// 1-D reduction descriptor for MRaw_.
// reduce_grid_desc_m is built from the MRaw member, which is declared before it
// and is therefore already initialized at that point.
__device__ EpilogueReduceCShuffle(
typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid_,
const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops_,
const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops_,
const index_t MRaw_)
: p_reduces_grid(p_reduces_grid_),
reduce_in_element_ops(reduce_in_element_ops_),
reduce_out_element_ops(reduce_out_element_ops_),
MRaw(MRaw_),
reduce_grid_desc_m{MakeReduceGridDescriptor_M(MRaw)}
{
}
// Epilogue body: shuffles the accumulator tile through LDS, writes E to global
// memory and, for every shuffle step, reduces the tile currently held in LDS and
// writes the per-row partial results to each reduction output.
//
// c_thread_buf   - per-thread accumulators produced by the blockwise GEMM
// p_ds_grid      - D tensor pointers (NumDTensor entries)
// p_e_grid       - E output pointer
// p_shared       - LDS scratch, reinterpreted as CShuffleDataType
// *_grid_desc_*  - blocked (MBlock, MPerBlock, NBlock, NPerBlock) descriptors of Ds / E
// cde_element_op - element-wise operator handed to the LDS -> global E transfer
// block_m_id / block_n_id - tile coordinates of this workgroup
template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename CThreadBuf,
typename DsGridPointer,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
__device__ void Run(CThreadBuf& c_thread_buf,
DsGridPointer p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
CDEElementwiseOperation& cde_element_op,
const index_t& block_m_id,
const index_t& block_n_id)
{
// Reduction output viewed as (MBlock, MPerBlock)
auto reduce_grid_desc_mblock_mperblock =
MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m);
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// C mapping in single thread.
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
BlockwiseGemmPipe::
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
// LDS buffer
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
.GetElementSpaceSize());
// Thread transfer Vgpr to LDS
auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor();
// Space Filling Curve Vgpr
constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{};
// Space Filling Curve Vmem
constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{};
// Block descriptor
constexpr auto
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
GetCShuffleLDSDescriptor();
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// Thread transfer LDS to Vmem
auto cde_shuffle_block_copy_lds_and_global =
Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, EDataType>(
c_ds_desc_refs,
e_grid_desc_mblock_mperblock_nblock_nperblock,
cde_element_op,
block_m_id,
block_n_id);
// tuple of reference to C/Ds tensor buffers
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// LDS c_reduce_block_desc_mperblock_nperblock
// 2-D view of the shuffle tile: dimensions 0 and 2 are frozen at index 0,
// dimensions 1 and 3 are kept.
constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_tuple(
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
I1)),
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
I3))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{}));
// The reduce thread cluster must use exactly BlockSize threads ...
static_assert(
ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) *
ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1) ==
BlockSize,
"wrong!");
// ... and must evenly divide the per-shuffle tile in both dimensions.
static_assert(
(CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) %
ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0) ==
0 &&
(CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) %
ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1) ==
0,
"wrong!");
// Number of rows / columns of the per-shuffle tile handled by one thread
constexpr index_t mreduce_per_thread =
(CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) /
ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I0);
constexpr index_t nreduce_per_thread =
(CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) /
ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_::At(I1);
static constexpr index_t NumReduce = ReduceTrait::ReducePtrsGlobal_::Size();
constexpr auto c_reduce_thread_lengths_mperblock_nperblock =
Sequence<mreduce_per_thread, nreduce_per_thread>{};
// VGPR c_reduce_thread_desc_mperblock_nperblock
constexpr auto c_reduce_thread_desc_mperblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));
// VGPR reduce_thread_desc_mperblock
constexpr auto reduce_thread_desc_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));
// VGPR reduce_thread_desc_mblock_mperblock
constexpr auto reduce_thread_desc_mblock_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
// Per-thread copy of its slice of the shuffle tile, in ReduceAccDataType.
// This single buffer is shared by all reductions below.
auto c_reduce_thread_buf =
make_static_buffer<AddressSpaceEnum::Vgpr, typename ReduceTrait::ReduceAccDataType_>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
// reduce: threadwise copy from LDS to VGPR
constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
typename ReduceTrait::CReduceThreadClusterLengths_MPerBlock_NPerBlock_{},
Sequence<1, 0>{});
const auto c_reduce_thread_cluster_idx = c_reduce_thread_cluster_desc.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
// Origin of this thread's slice inside the shuffle tile
const auto c_reduce_thread_data_idx_begin =
c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;
auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
CShuffleDataType,
typename ReduceTrait::ReduceAccDataType_,
decltype(c_reduce_block_desc_mperblock_nperblock),
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(c_reduce_thread_lengths_mperblock_nperblock),
Sequence<0, 1>,
1,
ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
1,
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
// One VGPR -> global writer per reduction. Each applies that reduction's output
// operator and uses the memory operation ReduceGlobalMemoryDataOperation_::At(I).
auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
[&](auto I) {
auto p_reduce_grid = p_reduces_grid[I];
auto reduce_acc_element_op = reduce_out_element_ops[I];
return ThreadwiseTensorSliceTransfer_v1r3<
typename ReduceTrait::ReduceAccDataType_,
remove_pointer_t<decltype(p_reduce_grid)>,
decltype(reduce_thread_desc_mblock_mperblock),
decltype(reduce_grid_desc_mblock_mperblock),
decltype(reduce_acc_element_op),
Sequence<1, mreduce_per_thread>,
Sequence<0, 1>,
1,
ReduceTrait::CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock_,
ReduceTrait::ReduceGlobalMemoryDataOperation_::At(I),
1,
false>{reduce_grid_desc_mblock_mperblock,
make_multi_index(block_m_id, // mblock
c_reduce_thread_data_idx_begin[I0]), // mperblock
reduce_acc_element_op};
},
Number<NumReduce>{});
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
// The VGPR-side and global-side curves must take the same number of steps
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
// CShuffle and Store
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block loads its C data from LDS, D from global, applies elementwise
// operation and stores result E to global
cde_shuffle_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
{
// Reduction part: reload this thread's slice of the same LDS tile
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
c_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0),
c_reduce_thread_buf);
static_for<0, NumReduce, 1>{}([&](auto In) {
auto& p_reduce_grid = p_reduces_grid[In];
auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto reduce_thread_buf =
make_static_buffer<AddressSpaceEnum::Vgpr,
typename ReduceTrait::ReduceAccDataType_>(
reduce_thread_desc_mperblock.GetElementSpaceSize());
auto& reduce_in_element_op = reduce_in_element_ops[In];
auto& reduce_thread_copy_vgpr_to_global =
reduce_tuple_thread_copy_vgpr_to_global(In);
using ReduceOperation =
remove_cvref_t<decltype(typename ReduceTrait::ReduceOperations_{}[In])>;
using ThreadwiseReduce =
ThreadwiseReduction<typename ReduceTrait::ReduceAccDataType_,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(reduce_thread_desc_mperblock),
ReduceOperation,
false>;
// Global write Gemm shuffle + reduction
// Start every per-row accumulator from the reduction's identity value
const auto reduce_identityVal = ReduceOperation::template GetIdentityValue<
typename ReduceTrait::ReduceAccDataType_>();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
// reduce in VGPR
// NOTE(review): the input operator is applied in place on c_reduce_thread_buf,
// which is loaded once per shuffle step and shared by all reductions. Reduction
// In therefore sees values already transformed by the input operators of
// reductions 0..In-1 - confirm this chaining is intended.
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
constexpr auto offset =
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
make_tuple(im, in))>{};
reduce_in_element_op(c_reduce_thread_buf(offset),
c_reduce_thread_buf(offset));
});
});
// Collapse the 2-D per-thread slice into mreduce_per_thread values
ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
// copy from VGPR to Global
reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
reduce_thread_buf,
reduce_grid_desc_mblock_mperblock,
reduce_grid_buf);
// Advance the reduction destination by the (MBlock, MPerBlock) part of the
// step taken by the E store
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_cde_global.GetForwardStep(access_id);
reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
reduce_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1]));
}
});
}
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_global_step);
});
// move on E
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
}
});
}
typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid;
typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops;
typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops;
index_t MRaw;
ReduceGridDesc_M reduce_grid_desc_m;
};
} // namespace ck

View File

@@ -3,6 +3,7 @@
#pragma once
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"

View File

@@ -25,6 +25,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
namespace ck {
@@ -622,6 +623,29 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
BlockwiseGemmPipe,
BlockSize>;
template <typename ReduceTrait>
using EpilogueReduceCShuffle = EpilogueReduceCShuffle<
DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
MPerBlock,
NPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
CDEElementwiseOperation,
ThisThreadBlock,
BlockwiseGemmPipe,
GemmSpec,
BlockSize,
ReduceTrait>;
template <typename DEGridDesc>
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)