Wmma support for gemm_bias_add_reduce (#3316)

* Add tests for gemm_bias_add_reduce

* Initial working implementation

* Generalize implementation of reduce epilogue

* Add tests for all layouts

* Add instances

* Fix test archs

* Fix xdl bug

* Remove library/profiler duplications

* Fix num_byted error profiler

* Fix typos

* Fix copyright

[ROCm/composable_kernel commit: aad4cf0985]
This commit is contained in:
Enrico Degregori
2026-01-07 19:27:16 +01:00
committed by GitHub
parent d074af36c9
commit 6eab5bea54
15 changed files with 1424 additions and 141 deletions

View File

@@ -0,0 +1,682 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseGemm,
typename ReduceTrait,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_gemm_bias_add_reduce_wmma_cshuffle_v3(
typename GridwiseGemm::Argument karg,
typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid,
const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops,
const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops,
const typename ReduceTrait::D0ElementwiseOperation_ d0_element_op)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<e_data_type, ck::half_t> ||
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
using EpilogueType = typename GridwiseGemm::template EpilogueReduceCShuffle<ReduceTrait>;
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
__shared__ char p_shared[LDS_size];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
auto epilogue_args = EpilogueType(
p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M, d0_element_op);
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg, epilogue_args);
#if defined(__gfx11__)
}
#endif
#else
ignore = karg;
ignore = p_reduces_grid;
ignore = reduce_in_element_ops;
ignore = reduce_out_element_ops;
ignore = d0_element_op;
#endif
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType,
typename BiasDataType,
typename D0DataType,
typename AccDataType,
typename CShuffleDataType,
typename ReduceAccDataType, // Reduce
typename ReducePtrsGlobal, // Reduce
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ElementwiseOperation,
typename ReduceOperations, // Reduce
typename ReduceInElementwiseOperations, // Reduce
typename ReduceAccElementwiseOperations, // Reduce
typename ReduceGlobalMemoryDataOperation, // Reduce
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock, // Reduce
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, // Reduce
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, // Reduce
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3
: public DeviceGemmReduce<1, ReduceOperations::Size()>
{
using CDEShuffleBlockTransferScalarPerVectors = Sequence<CShuffleBlockTransferScalarPerVector,
CShuffleBlockTransferScalarPerVector,
CShuffleBlockTransferScalarPerVector>;
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
Tuple<ELayout, ELayout>,
ELayout,
Tuple<ADataType>,
Tuple<BDataType>,
AccDataType,
CShuffleDataType,
Tuple<BiasDataType, D0DataType>,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>;
using ReduceTrait = ReduceTrait_<ReduceAccDataType,
ReducePtrsGlobal,
D0ElementwiseOperation,
ReduceOperations,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
ReduceGlobalMemoryDataOperation,
CReduceThreadClusterLengths_MPerBlock_NPerBlock,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>;
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
EDataType* p_e_grid,
const BiasDataType* p_bias_grid,
const D0DataType* p_d0_grid,
ReducePtrsGlobal p_reduces_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ElementwiseOperation d0_element_op,
ReduceInElementwiseOperations reduce_in_element_ops,
ReduceAccElementwiseOperations reduce_out_element_ops)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_e_grid_{p_e_grid},
p_bias_grid_{p_bias_grid},
p_d0_grid_{p_d0_grid},
p_reduces_grid_{p_reduces_grid},
MRaw_{MRaw},
NRaw_{NRaw},
KRaw_{KRaw},
StrideA_{StrideA},
StrideB_{StrideB},
StrideC_{StrideC},
StrideC1_{StrideC1},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
d0_element_op_{d0_element_op},
reduce_in_element_ops_{reduce_in_element_ops},
reduce_out_element_ops_{reduce_out_element_ops}
{
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
EDataType* p_e_grid_;
const BiasDataType* p_bias_grid_;
const D0DataType* p_d0_grid_;
ReducePtrsGlobal p_reduces_grid_;
index_t MRaw_;
index_t NRaw_;
index_t KRaw_;
index_t StrideA_;
index_t StrideB_;
index_t StrideC_;
index_t StrideC1_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
D0ElementwiseOperation d0_element_op_;
ReduceInElementwiseOperations reduce_in_element_ops_;
ReduceAccElementwiseOperations reduce_out_element_ops_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
typename GridwiseGemm::Argument gemm_arg{
std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},
std::array<const void*, 2>{arg.p_bias_grid_, arg.p_d0_grid_},
static_cast<EDataType*>(arg.p_e_grid_),
arg.MRaw_,
arg.NRaw_,
arg.KRaw_,
std::array<index_t, 1>{arg.StrideA_}, // StrideAs
std::array<index_t, 1>{arg.StrideB_}, // StrideBs
std::array<index_t, 2>{0, arg.StrideC1_}, // StrideDs
arg.StrideC_, // StrideE
1, // kbatch
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_};
if(stream_config.log_level_ > 0)
{
gemm_arg.Print();
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
}
if(!GridwiseGemm::CheckValidity(gemm_arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.MRaw_, arg.NRaw_, 1);
float ave_time = 0;
index_t K_split = (arg.KRaw_ + KPerBlock - 1) / KPerBlock * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.KRaw_);
const auto Run = [&](const auto& kernel) {
// Note: cache flushing not supported
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
gemm_arg,
arg.p_reduces_grid_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_,
arg.d0_element_op_);
};
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(TailNum == TailNumber::Full)
{
const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3<
GridwiseGemm,
ReduceTrait,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
else
{
throw std::runtime_error("wrong! Invalid pipeline setting");
}
}
}
else
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(TailNum == TailNumber::Full)
{
const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3<
GridwiseGemm,
ReduceTrait,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
else
{
throw std::runtime_error("wrong! Invalid pipeline v1 setting");
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(TailNum == TailNumber::Even)
{
const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3<
GridwiseGemm,
ReduceTrait,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
else if(TailNum == TailNumber::Odd)
{
const auto kernel = kernel_gemm_bias_add_reduce_wmma_cshuffle_v3<
GridwiseGemm,
ReduceTrait,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
throw std::runtime_error("wrong! Invalid pipeline v3 setting");
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Device implementation supports only gfx11 and gfx12! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
{
if(ck::is_gfx11_supported())
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "FP8 and BF8 not supported on gfx11! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
if((arg.KRaw_ % AK1 != 0 || arg.KRaw_ % BK1 != 0) &&
!(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Without padding, K must be divisible by AK1 and BK1! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
typename GridwiseGemm::Argument gemm_arg{
std::array<const void*, 1>{arg.p_a_grid_},
std::array<const void*, 1>{arg.p_b_grid_},
std::array<const void*, 2>{arg.p_bias_grid_, arg.p_d0_grid_},
static_cast<EDataType*>(arg.p_e_grid_),
arg.MRaw_,
arg.NRaw_,
arg.KRaw_,
std::array<index_t, 1>{arg.StrideA_}, // StrideAs
std::array<index_t, 1>{arg.StrideB_}, // StrideBs
std::array<index_t, 2>{0, arg.StrideC1_}, // StrideDs
arg.StrideC_, // StrideE
1, // kbatch
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_};
return GridwiseGemm::CheckValidity(gemm_arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static constexpr int NumReduce = ReduceOperations::Size();
static auto MakeArgument(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 1> p_ds,
void* p_c,
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 1> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 1> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op)
{
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
D0ElementwiseOperation d_element_op =
*(static_cast<D0ElementwiseOperation*>(d_element_ops[0]));
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<EDataType*>(p_c),
static_cast<const BiasDataType*>(p_bias),
static_cast<const D0DataType*>(p_ds[0]),
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideDs[0],
a_element_op,
b_element_op,
c_element_op,
d_element_op,
reduce_in_element_ops,
reduce_out_element_ops};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_bias,
std::array<const void*, 1> p_ds,
void* p_c,
std::array<void*, NumReduce> p_reduces,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
std::array<ck::index_t, 1> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 1> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op,
index_t /* KBatch */ = 1) override
{
ReducePtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
Number<NumReduce>{});
AElementwiseOperation a_element_op =
*(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
BElementwiseOperation b_element_op =
*(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
CElementwiseOperation c_element_op =
*(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
D0ElementwiseOperation d_element_op =
*(static_cast<D0ElementwiseOperation*>(d_element_ops[0]));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<EDataType*>(p_c),
static_cast<const BiasDataType*>(p_bias),
static_cast<const D0DataType*>(p_ds[0]),
reduce_tuple,
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideDs[0],
a_element_op,
b_element_op,
c_element_op,
d_element_op,
reduce_in_element_ops,
reduce_out_element_ops);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmBiasAddReduce_Wmma_CShuffleV3"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerWmma << ", "
<< NPerWmma << ", "
<< MRepeat << ", "
<< NRepeat << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMRepeatPerShuffle << ", "
<< CShuffleNRepeatPerShuffle
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -49,8 +49,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
auto epilogue_args =
EpilogueType(p_reduces_grid, reduce_in_element_ops, reduce_out_element_ops, karg.M);
auto epilogue_args = EpilogueType(p_reduces_grid,
reduce_in_element_ops,
reduce_out_element_ops,
karg.M,
tensor_operation::element_wise::PassThrough{});
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg, epilogue_args);
@@ -188,6 +191,7 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera
using ReduceTrait = ReduceTrait_<ReduceAccDataType,
ReducePtrsGlobal,
tensor_operation::element_wise::PassThrough,
ReduceOperations,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,

View File

@@ -10,6 +10,7 @@ namespace ck {
template <typename ReduceAccDataType,
typename ReducePtrsGlobal,
typename D0ElementwiseOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
@@ -21,6 +22,7 @@ struct ReduceTrait_
{
using ReduceAccDataType_ = ReduceAccDataType;
using ReducePtrsGlobal_ = ReducePtrsGlobal;
using D0ElementwiseOperation_ = D0ElementwiseOperation;
using ReduceOperations_ = ReduceOperations;
using ReduceInElementwiseOperations_ = ReduceInElementwiseOperations;
using ReduceAccElementwiseOperations_ = ReduceAccElementwiseOperations;
@@ -148,11 +150,13 @@ struct EpilogueReduceCShuffle
typename ReduceTrait::ReducePtrsGlobal_ p_reduces_grid_,
const typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops_,
const typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops_,
const index_t MRaw_)
const index_t MRaw_,
const typename ReduceTrait::D0ElementwiseOperation_ d0_element_op_)
: p_reduces_grid(p_reduces_grid_),
reduce_in_element_ops(reduce_in_element_ops_),
reduce_out_element_ops(reduce_out_element_ops_),
MRaw(MRaw_),
d0_element_op{d0_element_op_},
reduce_grid_desc_m{MakeReduceGridDescriptor_M(MRaw)}
{
}
@@ -174,6 +178,13 @@ struct EpilogueReduceCShuffle
const index_t& block_m_id,
const index_t& block_n_id)
{
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
auto reduce_grid_desc_mblock_mperblock =
MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m);
@@ -216,29 +227,6 @@ struct EpilogueReduceCShuffle
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
GetCShuffleLDSDescriptor();
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// Thread transfer LDS to Vmem
auto cde_shuffle_block_copy_lds_and_global =
Base::template GetLDSToVmemEpilogueDescriptor<EGlobalMemoryDataOperation, EDataType>(
c_ds_desc_refs,
e_grid_desc_mblock_mperblock_nblock_nperblock,
cde_element_op,
block_m_id,
block_n_id);
// tuple of reference to C/Ds tensor buffers
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// LDS c_reduce_block_desc_mperblock_nperblock
constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
@@ -346,6 +334,68 @@ struct EpilogueReduceCShuffle
},
Number<NumReduce>{});
// multiple Ds
constexpr auto d_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
constexpr auto ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock = generate_tuple(
[&](auto) { return d_reduce_thread_desc_mblock_mperblock_nblock_nperblock; },
Number<NumDTensor>{});
constexpr auto ds_thread_buf_size =
d_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
auto c01_thread_buf =
make_static_buffer<AddressSpaceEnum::Vgpr, typename ReduceTrait::ReduceAccDataType_>(
Number<ds_thread_buf_size>{});
auto ds_thread_copy_global_to_vgpr = generate_tuple(
[&](auto I) {
return ThreadwiseTensorSliceTransfer_v2<
remove_cvref_t<tuple_element_t<I.value, DsDataType>>,
typename ReduceTrait::ReduceAccDataType_,
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]),
remove_cvref_t<
decltype(ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I])>,
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
Sequence<0, 1, 2, 3>,
3,
ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
1,
true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
make_multi_index(
I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
},
Number<NumDTensor>{});
constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
// Write E from Vgpr to Vmem
auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
typename ReduceTrait::ReduceAccDataType_,
EDataType,
decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
tensor_operation::element_wise::PassThrough,
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
3, // DstVectorDim
ReduceTrait::CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock_,
EGlobalMemoryDataOperation,
1,
true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]),
NumDTensor > 0 ? tensor_operation::element_wise::PassThrough{} : cde_element_op};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
@@ -365,15 +415,6 @@ struct EpilogueReduceCShuffle
// make sure it's safe to read from LDS
block_sync_lds();
// each block loads its C data from LDS, D from global, applies elementwise
// operation and stores result E to global
cde_shuffle_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
{
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
@@ -381,6 +422,53 @@ struct EpilogueReduceCShuffle
make_tuple(I0, I0),
c_reduce_thread_buf);
// Note: currently multiple Ds supports only Bias + Add.
// It needs to be generalized for other operations (currently not needed)
if constexpr(NumDTensor > 0)
{
auto& d0_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I0);
// d0 / d1 operations
d0_thread_copy_global_to_vgpr.Run(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
ds_grid_buf[I0],
ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I0],
make_tuple(I0, I0, I0, I0),
c01_thread_buf);
// c = activation(c + bias)
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
typename ReduceTrait::ReduceAccDataType_ out;
cde_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i));
c_reduce_thread_buf(i) = out;
});
auto& d1_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I1);
d1_thread_copy_global_to_vgpr.Run(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
ds_grid_buf[I1],
ds_reduce_thread_desc_mblock_mperblock_nblock_nperblock[I1],
make_tuple(I0, I0, I0, I0),
c01_thread_buf);
// c = c + c1_function(c1)
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
d0_element_op(c01_thread_buf(i), c01_thread_buf(i));
c_reduce_thread_buf(i) += c01_thread_buf(i);
});
}
// Write E
c_reduce_thread_copy_vgpr_to_global.Run(
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c_reduce_thread_buf,
e_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_buf);
// Reduction
static_for<0, NumReduce, 1>{}([&](auto In) {
auto& p_reduce_grid = p_reduces_grid[In];
@@ -448,14 +536,15 @@ struct EpilogueReduceCShuffle
{
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_global_step);
static_for<0, NumDTensor, 1>{}([&](auto I) {
auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(I);
d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
ds_grid_desc_mblock_mperblock_nblock_nperblock[I], cde_global_step);
});
// move on E
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
e_grid_desc_mblock_mperblock_nblock_nperblock, cde_global_step);
}
});
}
@@ -464,6 +553,7 @@ struct EpilogueReduceCShuffle
typename ReduceTrait::ReduceInElementwiseOperations_ reduce_in_element_ops;
typename ReduceTrait::ReduceAccElementwiseOperations_ reduce_out_element_ops;
index_t MRaw;
typename ReduceTrait::D0ElementwiseOperation_ d0_element_op;
ReduceGridDesc_M reduce_grid_desc_m;
};

View File

@@ -897,6 +897,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),

View File

@@ -19,6 +19,7 @@ namespace instance {
using DeviceGemmAddAddMeanSquareMeanPtr = ck::tensor_operation::device::DeviceGemmReducePtr<1, 2>;
#if defined(CK_USE_XDL)
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
@@ -27,6 +28,18 @@ void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmAddAddMeanSquareMeanPtr>&);
#endif // CK_USE_WMMA
template <typename ADataType,
typename BDataType,
@@ -45,33 +58,61 @@ auto get_device_gemm_add_add_mean_squaremean_instances()
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#if defined(CK_USE_XDL)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
op_ptrs);
#endif
#if defined(CK_USE_WMMA)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
op_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#if defined(CK_USE_XDL)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
op_ptrs);
#endif
#if defined(CK_USE_WMMA)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
op_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#if defined(CK_USE_XDL)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
op_ptrs);
#endif
#if defined(CK_USE_WMMA)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
op_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#if defined(CK_USE_XDL)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
op_ptrs);
#endif
#if defined(CK_USE_WMMA)
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
op_ptrs);
#endif
}
}

View File

@@ -1,10 +1,15 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_bias_add_reduce_instance
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
)

View File

@@ -0,0 +1,85 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
// c[m, n] = a[k, m] * b[k, n]
using device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances =
std::tuple<
// clang-format off
//#####################################| ALayout| BLayout| ELayout|AData| BData| EData|BiasData|D0Data| AccData| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| D0| ReduceOperations| Reduce| Reduce| Reduce| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm|
//#####################################| | | | Type| Type| Type| Type| Type| Type| DataType| DataType| | Elementwise| Elementwise| Elementwise| Elementwise| | InElementwiseOperations| OutElementwiseOperations| GlobalMemory| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer|
//#####################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | DataOperation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | |
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 2, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 2, 16, 16, 2, 8, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 2, 16, 16, 4, 4, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 2, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 2, 2, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 2, 16, 16, 4, 2, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 2, 2, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 2, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 2, 2, 16, 16, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>
// // clang-format on
>;
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmReducePtr<1, ReduceOps::Size()>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,84 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
// c[m, n] = a[k, m] * b[n, k]
using device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances =
std::tuple<
// clang-format off
//#####################################| ALayout| BLayout| ELayout|AData| BData| EData|BiasData|D0Data| AccData| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| D0| ReduceOperations| Reduce| Reduce| Reduce| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm|
//#####################################| | | | Type| Type| Type| Type| Type| Type| DataType| DataType| | Elementwise| Elementwise| Elementwise| Elementwise| | InElementwiseOperations| OutElementwiseOperations| GlobalMemory| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer|
//#####################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | DataOperation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | |
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 2, 8, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 2, 8, 16, 16, 2, 8, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 2, 8, 16, 16, 4, 4, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 2, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 2, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 2, 8, 16, 16, 4, 2, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 2, 8, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 2, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 2, 8, 16, 16, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmReducePtr<1, ReduceOps::Size()>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,84 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
// c[m, n] = a[m, k] * b[n, k]
using device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances =
std::tuple<
// clang-format off
//#####################################| ALayout| BLayout| ELayout|AData| BData| EData|BiasData|D0Data| AccData| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| D0| ReduceOperations| Reduce| Reduce| Reduce| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm|
//#####################################| | | | Type| Type| Type| Type| Type| Type| DataType| DataType| | Elementwise| Elementwise| Elementwise| Elementwise| | InElementwiseOperations| OutElementwiseOperations| GlobalMemory| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer|
//#####################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | DataOperation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | |
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 2, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 2, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, S<16, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 8, 2, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 8, 2, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmReducePtr<1, ReduceOps::Size()>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,81 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
// c[m, n] = a[m, k] * b[n, k]
using device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances =
std::tuple<
// clang-format off
//#####################################| ALayout| BLayout| ELayout|AData| BData| EData|BiasData|D0Data| AccData| CShuffle| ReduceAcc| ReducePtrsGlobal| A| B| C| D0| ReduceOperations| Reduce| Reduce| Reduce| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| BlkGemm| BlkGemm|
//#####################################| | | | Type| Type| Type| Type| Type| Type| DataType| DataType| | Elementwise| Elementwise| Elementwise| Elementwise| | InElementwiseOperations| OutElementwiseOperations| GlobalMemory| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| PipeSched| PipelineVer|
//#####################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | DataOperation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| | |
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<32, 8>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, S<32, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 64, 32, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8, S<32, 2>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGemmBiasAddReduce_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceMemOp, GemmDefault, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, S<16, 4>, 4, 1, Intrawave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
void add_device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmReducePtr<1, ReduceOps::Size()>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bias_add_mean_squaremean_wmma_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -9,6 +9,8 @@
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -17,40 +19,6 @@
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F32 = float;
using F16 = ck::half_t;
using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<Identity, Square>;
using ReduceOutElementOps = ck::Tuple<Div, Div>;
using DeviceGemmBiasAddReduceNoOpPtr =
ck::tensor_operation::device::DeviceGemmReducePtr<1, ReducePtrsGlobal::Size()>;
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace ck {
namespace profiler {
@@ -63,7 +31,7 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout>
void profile_gemm_bias_add_reduce_impl(int do_verification,
bool profile_gemm_bias_add_reduce_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
@@ -75,6 +43,8 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
int StrideC,
int StrideD0)
{
bool pass = true;
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor({len}, {stride});
};
@@ -231,47 +201,19 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
bias_device_buf.ToDevice(bias_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
// add device GEMM instances
std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasAddReduceNoOpPtr> gemm_ptrs;
// get device op instances
const auto op_ptrs =
ck::tensor_operation::device::instance::get_device_gemm_add_add_mean_squaremean_instances<
ADataType,
BDataType,
CDataType,
ALayout,
BLayout,
CLayout>();
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::instance::
add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
gemm_ptrs);
}
}
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
if(gemm_ptrs.size() <= 0)
if(op_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device GEMM instance found");
}
@@ -282,29 +224,29 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
float best_gb_per_sec = 0;
// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
bias_device_buf.GetDeviceBuffer(),
{d0_device_buf.GetDeviceBuffer()},
c_device_buf.GetDeviceBuffer(),
p_reduces,
M,
N,
K,
StrideA,
StrideB,
StrideC,
{StrideD0},
gemm_element_ops,
{&d0_element_op},
reduce_in_element_ops,
reduce_out_element_ops);
auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
bias_device_buf.GetDeviceBuffer(),
{d0_device_buf.GetDeviceBuffer()},
c_device_buf.GetDeviceBuffer(),
p_reduces,
M,
N,
K,
StrideA,
StrideB,
StrideC,
{StrideD0},
gemm_element_ops,
{&d0_element_op},
reduce_in_element_ops,
reduce_out_element_ops);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// init DO, D1 to 0
reduce0_device_buf.SetZero();
@@ -313,12 +255,12 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::string gemm_name = gemm_ptr->GetTypeString();
std::string gemm_name = op_ptr->GetTypeString();
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N;
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N +
sizeof(CDataType) * M * N + sizeof(BiasDataType) * N +
sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M;
@@ -343,9 +285,13 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result);
ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result);
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
pass = pass & ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result);
pass = pass & ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result);
if(!pass)
{
std::cout << op_ptr->GetTypeString() << " failed" << std::endl;
}
if(do_log)
{
@@ -372,12 +318,14 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
}
else
{
std::cout << "does not support this GEMM problem" << std::endl;
std::cout << op_ptr->GetTypeString() << " does not support this GEMM problem"
<< std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
return pass;
}
} // namespace profiler

View File

@@ -258,6 +258,7 @@ add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm)
add_subdirectory(gemm_add)
add_subdirectory(gemm_bias_add_reduce)
add_subdirectory(gemm_blockscale_wp)
add_subdirectory(gemm_layernorm)
add_subdirectory(gemm_multi_abd)

View File

@@ -0,0 +1,9 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_gemm_bias_add_reduce_fp16 test_gemm_bias_add_reduce_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_bias_add_reduce_fp16 PRIVATE utility device_gemm_bias_add_reduce_instance)
endif()
endif()

View File

@@ -0,0 +1,106 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_common.hpp"
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
namespace {
template <typename X, typename Y>
struct tuple_concat;
template <typename... Xs, typename... Ys>
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
{
using type = std::tuple<Xs..., Ys...>;
};
} // namespace
template <typename Tuple>
class TestGemmBiasAddReduce_FP16_MK_NK
: public ck::test::TestGemmBiasAddReduceCommon<
typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmBiasAddReduce_FP16_MK_KN
: public ck::test::TestGemmBiasAddReduceCommon<
typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmBiasAddReduce_FP16_KM_KN
: public ck::test::TestGemmBiasAddReduceCommon<
typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmBiasAddReduce_FP16_KM_NK
: public ck::test::TestGemmBiasAddReduceCommon<
typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple< F16, F16, F16, F16, F16, F32>
>;
// clang-format on
TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_MK_NK, KernelTypes);
TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_MK_KN, KernelTypes);
TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_KM_KN, KernelTypes);
TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_KM_NK, KernelTypes);
TYPED_TEST(TestGemmBiasAddReduce_FP16_MK_NK, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 1024;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestGemmBiasAddReduce_FP16_MK_KN, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 1024;
constexpr int K = 1024;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestGemmBiasAddReduce_FP16_KM_KN, Regular)
{
std::vector<int> Ms{256};
constexpr int N = 512;
constexpr int K = 1024;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestGemmBiasAddReduce_FP16_KM_NK, Regular)
{
std::vector<int> Ms{256};
constexpr int N = 1024;
constexpr int K = 1024;
for(int M : Ms)
this->Run(M, N, K);
}

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_bias_add_reduce_impl.hpp"
namespace ck {
namespace test {
using Row = ck::tensor_layout::gemm::RowMajor;
using F32 = float;
template <typename Tuple>
class TestGemmBiasAddReduceCommon : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = Row;
using ADataType = std::tuple_element_t<2, Tuple>;
using BDataType = std::tuple_element_t<3, Tuple>;
using CDataType = std::tuple_element_t<4, Tuple>;
using BiasDataType = std::tuple_element_t<5, Tuple>;
using D0DataType = std::tuple_element_t<6, Tuple>;
using ReduceDataType = std::tuple_element_t<7, Tuple>;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // integer value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
void Run(const int M, const int N, const int K)
{
bool all_success = true;
int StrideA = std::is_same_v<remove_cvref_t<ALayout>, Row> ? K : M;
int StrideB = std::is_same_v<remove_cvref_t<BLayout>, Row> ? N : K;
int StrideD0 = std::is_same_v<remove_cvref_t<CLayout>, Row> ? N : M;
int StrideC = std::is_same_v<CLayout, Row> ? N : M;
all_success =
all_success &
ck::profiler::profile_gemm_bias_add_reduce_impl<ADataType,
BDataType,
CDataType,
BiasDataType,
D0DataType,
ReduceDataType,
ALayout,
BLayout,
CLayout>(
verify_, init_method_, log_, bench_, M, N, K, StrideA, StrideB, StrideC, StrideD0);
EXPECT_TRUE(all_success);
}
};
} // namespace test
} // namespace ck