Implement DeviceGemmMultipleD_Wmma_CShuffleV3

This commit is contained in:
Anton Gorenko
2025-05-29 16:31:41 +05:00
parent 89ac60de6d
commit 137efa743d
3 changed files with 115 additions and 86 deletions

View File

@@ -10,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -21,13 +21,14 @@ namespace ck {
namespace tensor_operation {
namespace device {
/// @brief \"Universal\" GEMM operation with SplitK support.
/// @brief \"Universal\" GEMM operation with SplitK support and multiple D tensors.
///
/// @par Overview
/// This GEMM operation implements the following mathematical equation:
/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
/// elementwise operations applied to the A, B, and C tensors, respectively.
/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...)
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise
// operations that could be applied on each tensor respectively. The CDE_op is an
// elementwise operation applied to the C and all D tensors.
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
/// scenarios. That's why it's called \"universal\". It's universal through it's design
/// and versatilty.
@@ -39,18 +40,20 @@ namespace device {
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam CLayout C tensor data layout.
/// @tparam DsLayout D tensors data layouts.
/// @tparam ELayout E tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam CDataType C tensor data type.
/// @tparam DsDataType D tensors data types.
/// @tparam EDataType E tensor data type.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during \"CShuffle\" data layout optimization.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
/// (after GEMM).
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
/// GEMM) and D input tensors.
/// @tparam GemmSpec Determines used "padding" version.
/// @tparam BlockSize The number of threads within workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
@@ -104,11 +107,12 @@ namespace device {
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
/// Used when storing data to output tensor.
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
/// Used when loading data from D tensors and storing data
/// to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
@@ -122,15 +126,17 @@ namespace device {
/// in global memory (pre-shuffled).
template <typename ALayout,
typename BLayout,
typename CLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
@@ -158,39 +164,43 @@ template <typename ALayout,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
struct DeviceGemmMultipleD_Wmma_CShuffleV3
: public DeviceGemmMultipleDSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
// GridwiseGemm
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
Tuple<>, // DsDataType
CDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
@@ -220,8 +230,8 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -285,8 +295,21 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType) / GridwiseGemm::BPackedSize;
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
std::array<std::size_t, NumDTensor> size_ds_buffers;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
size_ds_buffers[i] =
ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
});
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
arg_,
stream_config.rotating_count,
size_a_buffer,
size_b_buffer,
size_ds_buffers);
rotating_mem.Print();
auto run_flush_cache = [&]() {
@@ -298,7 +321,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
if(arg_.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
arg_.M * arg_.N * sizeof(EDataType),
stream_config.stream_id_));
};
@@ -316,7 +339,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
if(arg.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid,
0,
arg.M * arg.N * sizeof(CDataType),
arg.M * arg.N * sizeof(EDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
@@ -419,8 +442,8 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
return false;
}
if constexpr(std::is_same_v<CDataType, ck::half_t> ||
std::is_same_v<CDataType, ck::bhalf_t>)
if constexpr(std::is_same_v<EDataType, ck::half_t> ||
std::is_same_v<EDataType, ck::bhalf_t>)
{
if(arg.KBatch > 1 && ck::is_gfx11_supported())
{
@@ -455,36 +478,33 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
index_t GetKPerBlock() override { return KPerBlock; }
bool GetPermuteA() override { return PermuteA; }
bool GetPermuteB() override { return PermuteB; }
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation cde_element_op)
CDEElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
std::array<const void*, 0>{}, // p_ds_grid_
p_c,
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<EDataType*>(p_e),
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 0>{}, // StrideDs_
StrideC,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
@@ -494,35 +514,38 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
std::array<const void*, 0>{}, // p_ds_grid_
static_cast<CDataType*>(p_c),
p_ds,
static_cast<EDataType*>(p_e),
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 0>{}, // StrideDs_
StrideC,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
c_element_op);
cde_element_op);
}
// polymorphic
@@ -548,12 +571,17 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemm_Wmma_CShuffleV3"
str << "DeviceGemmMultipleD_Wmma_CShuffleV3"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< std::string(BLayout::name)[0];
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
str << std::string(DLayout::name)[0];
});
str << std::string(ELayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "

View File

@@ -176,7 +176,6 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
BElementwiseOperation,
CElementwiseOperation>
{
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,

View File

@@ -64,8 +64,9 @@ __global__ void
/// @par Overview
/// This GEMM kernel is carrying out following mathematical equation:
/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...)
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B/CDE_op are
/// elementwise operations that could be applied on each tensor respectively.
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise
// operations that could be applied on each tensor respectively. The CDE_op is an
// elementwise operation applied to the C and all D tensors.
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
/// scenarios. That's why it's called \"universal\". It's universal through it's design
/// and versatilty.
@@ -77,7 +78,7 @@ __global__ void
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam DsLayout D tensors data layout.
/// @tparam DsLayout D tensors data layouts.
/// @tparam ELayout E tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
@@ -85,6 +86,7 @@ __global__ void
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during \"CShuffle\" data layout optimization.
/// @tparam DsDataType D tensors data types.
/// @tparam EDataType E tensor data type.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
@@ -542,7 +544,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
}
template <typename DELayout>
__device__ static auto
__host__ __device__ static auto
MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
{
const auto c_grid_desc_mraw_nraw = [&]() {
@@ -620,7 +622,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
using DsGridPointer = decltype(MakeDsGridPointer());
__device__ static auto MakeDsGridDescriptor_M_N(
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
{
return generate_tuple(
@@ -747,9 +749,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
is_reduce(is_reduce_)
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
});
}