Merge branch 'develop' into ginolu/add_wgmfma_dispatcher

This commit is contained in:
Gino Lu
2025-09-08 19:10:23 -05:00
147 changed files with 10722 additions and 1484 deletions

View File

@@ -149,50 +149,105 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
#endif
};
/// @brief Wrapper for backward compatibility that allows to use instances of
/// DeviceGemmMultipleDSplitK in contexts where DeviceGemmMultipleD is expected.
///
/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances().
/// The only difference between API of DeviceGemmMultipleD and DeviceGemmMultipleDSplitK is
/// that DeviceGemmMultipleDSplitK::MakeArgumentPointer requires an additional parameter
/// KBatch which is explicitly passed as 1 by this wrapper.
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename DsDataType,
typename EDataType,
index_t ScaleBlockSize,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceMoEGemmMXBPreShuffle : public BaseOperator
struct DeviceGemmMultipleDSplitKWrapper : public DeviceGemmMultipleD<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
static constexpr index_t NumDTensor = DsDataType::Size();
#ifndef CK_CODE_GEN_RTC
virtual std::unique_ptr<BaseArgument>
#ifndef __HIPCC_RTC__
explicit DeviceGemmMultipleDSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
: p_op_(std::move(p_op))
{
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return p_op_->IsSupportedArgument(p_arg);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideAScale,
ck::index_t StrideB,
ck::index_t StrideBScale,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
CDEElementwiseOperation cde_element_op) override
{
return p_op_->MakeArgumentPointer(p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
1, // KBatch
a_element_op,
b_element_op,
cde_element_op);
}
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return p_op_->MakeInvokerPointer();
}
virtual int GetPreShuffleParameters() = 0;
#endif
std::string GetTypeString() const override { return p_op_->GetTypeString(); }
private:
std::unique_ptr<DeviceOp> p_op_;
#endif // __HIPCC_RTC__
};
} // namespace device

View File

@@ -40,7 +40,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<c_data_type, ck::half_t> ||
std::is_same_v<c_data_type, ck::bhalf_t>)))
@@ -62,14 +62,18 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset + a_batch_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset + b_batch_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
p_shared,
karg);
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op);
#if defined(__gfx11__)
}
#endif
@@ -272,11 +276,13 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
Tuple<>, // DsDataType
CDataType,
AElementwiseOperation,
BElementwiseOperation,
@@ -311,7 +317,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -336,17 +342,25 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
index_t BatchStrideC_,
index_t Batch_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
: GridwiseGemm::Argument(p_a_grid_,
p_b_grid_,
std::array<const void*, 0>{}, // p_ds_grid_
p_c_grid_,
M_,
N_,
K_,
StrideA_,
StrideB_,
std::array<index_t, 0>{}, // StrideDs_
StrideC_,
k_batch_,
a_element_op_,
b_element_op_,
cde_element_op_,
is_reduce_),
Batch(Batch_),
compute_ptr_offset_of_batch{BatchStrideA_, BatchStrideB_, BatchStrideC_}
@@ -443,7 +457,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
// Note: This seems incorrect for non-contiguous memory layouts for C
// (padding, gaps).
HIP_CHECK_ERROR(
hipMemsetAsync(arg_.p_c_grid,
hipMemsetAsync(arg_.p_e_grid,
0,
arg_.Batch * arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
@@ -469,7 +483,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
// Note: This seems incorrect for non-contiguous memory layouts for C
// (padding, gaps).
HIP_CHECK_ERROR(
hipMemsetAsync(arg.p_c_grid,
hipMemsetAsync(arg.p_e_grid,
0,
arg.Batch * arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
@@ -658,7 +672,10 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
BatchStrideB,
BatchStrideC,
Batch,
1 /* KBatch */};
1, /* KBatch */
AElementwiseOperation{},
BElementwiseOperation{},
CElementwiseOperation{}};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -694,7 +711,10 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
BatchStrideB,
BatchStrideC,
Batch,
1); // KBatch
1,
AElementwiseOperation{},
BElementwiseOperation{},
CElementwiseOperation{}); // KBatch
}
// polymorphic

View File

@@ -0,0 +1,410 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/// @brief \"Universal\" GEMM operation with SplitK support and multiple D tensors.
///
/// @par Overview
/// This GEMM operation implements the following mathematical equation:
/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...)
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise
// operations that could be applied on each tensor respectively. The CDE_op is an
// elementwise operation applied to the C and all D tensors.
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
/// scenarios. That's why it's called \"universal\". It's universal through it's design
/// and versatilty.
///
/// @note This Kernel implementation supports SplitK algorithm. It can be configured
/// to split the dot product accumulated over the K dimension into multiple working groups.
/// The partial products of different workgroups are then reduced using the AtomicAdd
/// operation.
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam DsLayout D tensors data layouts.
/// @tparam ELayout E tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam DsDataType D tensors data types.
/// @tparam EDataType E tensor data type.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during \"CShuffle\" data layout optimization.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
/// GEMM) and D input tensors.
/// @tparam GemmSpec Determines used "padding" version.
/// @tparam BlockSize The number of threads within workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
/// @tparam NPerBlock The input/output data tile size in the N dimension.
/// @tparam KPerBlock The input data tile size in the K dimension.
/// @tparam AK1 The vector load size from global memory for A tensor.
/// @tparam BK1 The vector load size from global memory for B tensor.
/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction.
/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction.
/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront.
/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront.
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question, "How many threads can be
/// arranged on each input data axis?"
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question: "How many threads to
/// arrange on each input data axis?"
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in M dimension.
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
/// Used when loading data from D tensors and storing data
/// to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
/// instructions.
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
/// instructions.
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
/// in global memory. Currently not supported!
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
/// in global memory (pre-shuffled).
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGemmMultipleD_Wmma_CShuffleV3
: public DeviceGemmMultipleDSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>;
using Argument = typename GridwiseGemm::Argument;
using DeviceGemmCommon =
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
ADataType,
BDataType,
DsDataType,
EDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// Invoker
using Invoker = typename DeviceGemmCommon::Invoker;
static bool IsSupportedArgument(const Argument& arg)
{
return DeviceGemmCommon::IsSupportedArgument(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<EDataType*>(p_e),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<EDataType*>(p_e),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmMultipleD_Wmma_CShuffleV3"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0];
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
str << std::string(DLayout::name)[0];
});
str << std::string(ELayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
<< "WaveTile: "
<< MPerWmma << "x"<<NPerWmma << ", "
<< "WaveMap: "
<< MRepeat << "x" << NRepeat << ", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "KPack: "
<< GridwiseGemm::KPack;
// clang-format on
return str.str();
}
REGISTER_EXTRA_PRINTING_METHODS
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -177,15 +177,16 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
BElementwiseOperation,
CElementwiseOperation>
{
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
Tuple<>, // DsDataType
CDataType,
AElementwiseOperation,
BElementwiseOperation,
@@ -220,7 +221,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -230,21 +231,24 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
using Argument = typename GridwiseGemm::Argument;
using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
ADataType,
BDataType,
CDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using DeviceGemmCommon =
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
ADataType,
BDataType,
Tuple<>,
CDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// Invoker
using Invoker = typename DeviceGemmCommon::Invoker;
@@ -275,11 +279,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
index_t StrideB,
index_t StrideC,
index_t KBatch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation cde_element_op)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
return Argument{p_a,
p_b,
std::array<const void*, 0>{}, // p_ds_grid_
p_c,
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 0>{}, // StrideDs_
StrideC,
KBatch,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -295,20 +313,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
index_t StrideB,
index_t StrideC,
index_t KBatch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
std::array<const void*, 0>{}, // p_ds_grid_
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 0>{}, // StrideDs_
StrideC,
KBatch);
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic

View File

@@ -89,11 +89,13 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale<
ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
Tuple<>, // DsDataType
CDataType,
AElementwiseOperation,
BElementwiseOperation,
@@ -130,7 +132,7 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -140,21 +142,24 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
using Argument = typename GridwiseGemm::Argument;
using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
ADataType,
BDataType,
CDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using DeviceGemmCommon =
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
ADataType,
BDataType,
Tuple<>,
CDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// Invoker
using Invoker = typename DeviceGemmCommon::Invoker;
@@ -188,23 +193,25 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
std::array<const void*, 0>{}, // p_ds_grid_
p_c,
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 0>{}, // StrideDs_
StrideC,
StrideScaleB,
p_b_scale,
KBatch,
a_element_op,
b_element_op,
c_element_op};
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -228,12 +235,14 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
std::array<const void*, 0>{}, // p_ds_grid_
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 0>{}, // StrideDs_
StrideC,
StrideScaleB,
static_cast<const BScaleDataType*>(p_b_scale),

View File

@@ -24,7 +24,8 @@ namespace device {
template <typename GridwiseGemm,
typename ADataType,
typename BDataType,
typename CDataType,
typename DsDataType,
typename EDataType,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
@@ -32,6 +33,7 @@ template <typename GridwiseGemm,
index_t AK1,
index_t BK1,
GemmSpecialization GemmSpec,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
@@ -95,8 +97,22 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType) / GridwiseGemm::BPackedSize;
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
std::array<std::size_t, GridwiseGemm::NumDTensor> size_ds_buffers;
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
size_ds_buffers[i] =
ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
});
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
arg_,
stream_config.rotating_count,
size_a_buffer,
size_b_buffer,
size_ds_buffers);
rotating_mem.Print();
auto run_flush_cache = [&]() {
@@ -106,9 +122,9 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid,
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
arg_.M * arg_.N * sizeof(EDataType),
stream_config.stream_id_));
};
@@ -124,9 +140,9 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
else
{
if(arg.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid,
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid,
0,
arg.M * arg.N * sizeof(CDataType),
arg.M * arg.N * sizeof(EDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
@@ -149,6 +165,16 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
}
}();
// ThreadwiseTensorSliceTransfer_v7r3 (used in ThreadGroupTensorSliceTransfer_v7r3) is
// currently implemented in such a way that all SrcScalarPerVectors must be the same, so
// if one of D matrices is column-major, then all SrcScalarPerVectors must be 1. On the
// other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot
// be odd.
constexpr bool AtomicsImplementationExists =
!(std::is_same_v<EDataType, ck::half_t> ||
std::is_same_v<EDataType, ck::bhalf_t>) ||
(CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0);
if(has_main_k_block_loop)
{
// Tail number always full
@@ -157,12 +183,15 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
if constexpr(AtomicsImplementationExists)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
}
else
{
@@ -186,12 +215,15 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
if constexpr(AtomicsImplementationExists)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
}
else
{
@@ -229,8 +261,8 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
return false;
}
if constexpr(std::is_same_v<CDataType, ck::half_t> ||
std::is_same_v<CDataType, ck::bhalf_t>)
if constexpr(std::is_same_v<EDataType, ck::half_t> ||
std::is_same_v<EDataType, ck::bhalf_t>)
{
if(arg.KBatch > 1 && ck::is_gfx11_supported())
{

View File

@@ -47,7 +47,7 @@ struct Add
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
{
y = type_convert<half_t>(x0) + x1;
y = x0 + type_convert<float>(x1);
};
template <>

View File

@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp"
@@ -22,9 +22,10 @@ namespace ck {
///
/// @par Overview
/// This GEMM kernel is carrying out following mathematical equation:
/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
/// elementwise operations that could be applied on each tensor respectively.
/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...)
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise
// operations that could be applied on each tensor respectively. The CDE_op is an
// elementwise operation applied to the C and all D tensors.
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
/// scenarios. That's why it's called \"universal\". It's universal through it's design
/// and versatilty.
@@ -36,18 +37,20 @@ namespace ck {
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam CLayout C tensor data layout.
/// @tparam DsLayout D tensors data layouts.
/// @tparam ELayout E tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during \"CShuffle\" data layout optimization.
/// @tparam CDataType C tensor data type.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
/// (after GEMM).
/// @tparam DsDataType D tensors data types.
/// @tparam EDataType E tensor data type.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
/// GEMM) and D input tensors.
/// @tparam GemmSpec Determines used "padding" version.
/// @tparam BlockSize The number of threads within workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
@@ -105,11 +108,12 @@ namespace ck {
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
/// Used when storing data to output tensor.
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
/// Used when loading data from D tensors and storing data
/// to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
@@ -123,15 +127,17 @@ namespace ck {
/// in global memory (pre-shuffled).
template <typename ALayout,
typename BLayout,
typename CLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
@@ -161,8 +167,8 @@ template <typename ALayout,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
@@ -173,15 +179,17 @@ struct GridwiseGemm_wmma_cshuffle_v3
: GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
CLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
@@ -211,8 +219,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -223,15 +231,17 @@ struct GridwiseGemm_wmma_cshuffle_v3
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
CLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
@@ -261,8 +271,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -297,17 +307,22 @@ struct GridwiseGemm_wmma_cshuffle_v3
using Base::CalculateNPadded;
using Base::MakeAGridDescriptor_AK0_M_AK1;
using Base::MakeBGridDescriptor_BK0_N_BK1;
using Base::MakeCGridDescriptor_M_N;
using Base::MakeDEGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
using Base::NumDTensor;
using typename Base::DsGridPointer;
struct Problem
{
__host__ Problem(index_t M_,
@@ -315,14 +330,16 @@ struct GridwiseGemm_wmma_cshuffle_v3
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t KBatch_)
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
StrideDs{StrideDs_},
StrideE{StrideE_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
@@ -338,11 +355,19 @@ struct GridwiseGemm_wmma_cshuffle_v3
__host__ void Print() const
{
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
<< ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
<< ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << "}" << std::endl;
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
if constexpr(NumDTensor > 0)
{
std::cout << "SDs: { ";
static_for<0, NumDTensor, 1>{}([&](auto i) {
std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
});
std::cout << " }, ";
}
std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
<< ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
<< "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
<< ", " << "NBlock: " << NBlock << "}" << std::endl;
}
index_t M;
@@ -350,7 +375,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
index_t KBatch;
index_t MPadded;
index_t NPadded;
@@ -367,21 +393,35 @@ struct GridwiseGemm_wmma_cshuffle_v3
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
EDataType* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CDEElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
p_ds_grid{},
p_e_grid{p_e_grid_},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
cde_element_op{cde_element_op_},
is_reduce(is_reduce_)
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
});
}
__host__ __device__ inline bool IsReduceAdd() const
@@ -396,42 +436,49 @@ struct GridwiseGemm_wmma_cshuffle_v3
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CDEElementwiseOperation cde_element_op;
// TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
bool is_reduce;
};
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg)
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
a_k_split_offset = k_id * karg.KRead / APackedSize;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
b_k_split_offset = k_id * karg.KRead / BPackedSize;
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
b_k_split_offset = k_id * k0_offset / BPackedSize;
}
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
if(k_id < karg.KBatch - 1)
{
karg.K = karg.KRead;
}
@@ -442,7 +489,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
if(karg.IsReduceAdd())
{
c_reduce_offset = blockIdx.z * karg.M * karg.N;
c_reduce_offset = k_id * karg.M * karg.N;
}
else
{
@@ -465,23 +512,32 @@ struct GridwiseGemm_wmma_cshuffle_v3
__device__ static index_t GetKBlockPerScale() { return 1; }
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const Problem& problem)
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
@@ -491,8 +547,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
@@ -508,17 +564,23 @@ struct GridwiseGemm_wmma_cshuffle_v3
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(b_scale_struct),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
EGlobalMemoryDataOperation,
TailNum>(p_a_grid,
p_b_grid,
p_c_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
@@ -528,17 +590,21 @@ struct GridwiseGemm_wmma_cshuffle_v3
// Wrapper function to have __global__ function in common
// between gemm_universal, b_scale, ab_scale, etc.
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg)
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
{
Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg);
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op);
}
};

View File

@@ -20,15 +20,17 @@ namespace ck {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockN, // scale N
@@ -60,11 +62,11 @@ template <typename ALayout,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = CDataType,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
@@ -72,15 +74,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
: GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
CLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
@@ -110,8 +114,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -124,15 +128,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
CLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
@@ -162,8 +168,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -198,17 +204,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
using Base::CalculateNPadded;
using Base::MakeAGridDescriptor_AK0_M_AK1;
using Base::MakeBGridDescriptor_BK0_N_BK1;
using Base::MakeCGridDescriptor_M_N;
using Base::MakeDEGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
using Base::NumDTensor;
using typename Base::DsGridPointer;
struct Problem
{
__host__ Problem(index_t M_,
@@ -216,7 +227,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleB_,
index_t KBatch_)
: M{M_},
@@ -224,7 +236,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
StrideDs{StrideDs_},
StrideE{StrideE_},
StrideScaleB{StrideScaleB_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
@@ -241,11 +254,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
__host__ void Print() const
{
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
<< ", " << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded
<< ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
if constexpr(NumDTensor > 0)
{
std::cout << "SDs: { ";
static_for<0, NumDTensor, 1>{}([&](auto i) {
std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
});
std::cout << " }, ";
}
std::cout << "SE:" << StrideE << ", " << "SScaleB:" << StrideScaleB << ", "
<< "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead
<< ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0
<< ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}"
<< std::endl;
}
index_t M;
@@ -253,7 +275,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
index_t StrideScaleB;
index_t KBatch;
index_t MPadded;
@@ -271,30 +294,38 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
EDataType* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleB_,
const BScaleType* p_b_scale_grid_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_,
CDEElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_},
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, StrideScaleB_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
p_ds_grid{},
p_e_grid{p_e_grid_},
p_b_scale_grid{p_b_scale_grid_},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
c_element_op{c_element_op_},
cde_element_op{cde_element_op_},
is_reduce(is_reduce_)
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
});
}
__host__ __device__ inline bool IsReduceAdd() const
@@ -309,57 +340,58 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
const BScaleType* p_b_scale_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CElementwiseOperation c_element_op;
const CDEElementwiseOperation cde_element_op;
bool is_reduce;
};
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg)
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
a_k_split_offset = k_id * karg.KRead / APackedSize;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
b_k_split_offset = k_id * karg.KRead / BPackedSize;
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
b_k_split_offset = k_id * k0_offset / BPackedSize;
}
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB;
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
if(k_id < karg.KBatch - 1)
{
karg.K = karg.KRead;
}
@@ -370,7 +402,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
if(karg.IsReduceAdd())
{
c_reduce_offset = blockIdx.z * karg.M * karg.N;
c_reduce_offset = k_id * karg.M * karg.N;
}
else
{
@@ -454,24 +486,33 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
const BScaleType* p_b_scale_grid,
void* p_shared,
const Problem& problem)
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
// B Scale grid
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
@@ -487,8 +528,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
@@ -503,17 +544,23 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(b_scale_struct),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
EGlobalMemoryDataOperation,
TailNum>(p_a_grid,
p_b_grid,
p_c_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
@@ -523,18 +570,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
// NOTE: Wrapper function to have __global__ function in common
// between gemm_universal, b_scale, ab_scale, etc.
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg)
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
{
Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
p_shared,
karg);
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op);
}
};

View File

@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -19,7 +19,7 @@ namespace ck {
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
@@ -31,17 +31,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<c_data_type, ck::half_t> ||
std::is_same_v<c_data_type, ck::bhalf_t>)))
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<e_data_type, ck::half_t> ||
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg);
#if defined(__gfx11__)
@@ -54,15 +54,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
template <typename ALayout,
typename BLayout,
typename CLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
@@ -92,8 +94,8 @@ template <typename ALayout,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
@@ -112,6 +114,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto EShuffleBlockTransferScalarPerVector =
CDEShuffleBlockTransferScalarPerVectors{}[I0];
// K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
@@ -430,17 +435,18 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return MakeWmmaTileDescriptor<NRepeat, NWaves, NPerWmma>(BBlockDesc_BK0_N_BK1{});
}
template <typename DELayout>
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideDE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideDE));
}
}();
@@ -493,6 +499,44 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
#endif
}
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
using DsGridPointer = decltype(MakeDsGridPointer());
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return MakeDEGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
},
Number<NumDTensor>{});
}
template <typename DsGridDesc>
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
return generate_tuple(
[&](auto i) {
return MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n[i], MBlock, NBlock);
},
Number<NumDTensor>{});
}
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
@@ -805,18 +849,18 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
NRepeat,
KPack>())>;
template <typename CGridDesc>
__host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
template <typename DEGridDesc>
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
de_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
return de_grid_desc_mblock_mperblock_nblock_nperblock;
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
@@ -950,56 +994,51 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
if(karg.N % EShuffleBlockTransferScalarPerVector != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
"EShuffleBlockTransferScalarPerVector ("
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
else
{
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
if(karg.M % EShuffleBlockTransferScalarPerVector != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
"EShuffleBlockTransferScalarPerVector ("
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
is_same<remove_cvref_t<CDataType>, float>::value ||
is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
is_same<remove_cvref_t<CDataType>, int32_t>::value))
if constexpr(!(is_same<remove_cvref_t<EDataType>, half_t>::value ||
is_same<remove_cvref_t<EDataType>, float>::value ||
is_same<remove_cvref_t<EDataType>, bhalf_t>::value ||
is_same<remove_cvref_t<EDataType>, int32_t>::value))
{
if(!karg.IsReduceAdd())
if(karg.IsAtomicAdd() && karg.KBatch > 1)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
if(karg.KBatch > 1)
{
return false;
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported for this "
<< "destination type (EDataType) " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
}
return false;
}
}
@@ -1062,19 +1101,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename BScaleStruct,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DsGridPointer p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
const index_t& block_m_id,
const index_t& block_n_id,
const index_t& num_k_block_per_scale,
@@ -1084,12 +1130,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
@@ -1330,31 +1379,58 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
m_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor buffers
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
Number<NumDTensor>{}));
// blockwise copy which loads C from LDS, D from global, applies elementwise
// operation and stores result E to global
auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
ThisThreadBlock, // ThreadGroup
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation, // ElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
3, // SrcVectorDim,
3, // DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
cde_element_op};
// space filling curve for local reg & global memory
// space filling curve for threadwise C in VGPR
@@ -1370,7 +1446,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
MAccVgprs>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
constexpr auto sfc_cde_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
@@ -1380,7 +1456,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
@@ -1397,20 +1473,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
// each block loads its C data from LDS, D from global, applies elementwise
// operation and stores result E to global
cde_shuffle_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_global_step);
});
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
// move on E
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
}
});
}

View File

@@ -165,6 +165,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3
oob_val = oob_val & is_src_valid;
// TODO: With column-major matrices this step restricts the transferred tensor slice
// to just one element, which consequently prevents using atomic operations if the
// matrix data type is on 16 bits.
if constexpr(SrcScalarPerVectors{}[i] == 1)
{
auto data_types = SrcDatas{};

View File

@@ -270,8 +270,8 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
class FloatA,
class FloatB,
class FloatC,
bool neg_a = false,
bool neg_b = false,
bool neg_a = true,
bool neg_b = true,
bool clamp = false>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
@@ -390,8 +390,8 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
class FloatA,
class FloatB,
class FloatC,
bool neg_a = false,
bool neg_b = false,
bool neg_a = true,
bool neg_b = true,
bool clamp = false>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
@@ -793,6 +793,9 @@ struct WmmaGemm
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
"((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!");
static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
// Integer wmma operators need extra input flags to indicate if the input is signed or
// unsigned. At the moment CK supports only signed integer inputs, so these flags are
// hardcoded.
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);

View File

@@ -162,9 +162,15 @@ CK_TILE_HOST_DEVICE static void print(const tensor_descriptor<Transforms,
{
printf("tensor_descriptor{\n");
// first print the tensor adaptor part of the descriptor using the base class print
print(static_cast<const typename decltype(descriptor)::Base&>(descriptor));
printf("element_space_size_: %ld,\n",
static_cast<long>(descriptor.get_element_space_size().value));
using Base = typename tensor_descriptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
TopDimensionHiddenIds,
ElementSpaceSize,
GuaranteedVectorLengths,
GuaranteedVectorStrides>::Base;
print(static_cast<const Base&>(descriptor));
printf("element_space_size_: %ld,\n", static_cast<long>(descriptor.get_element_space_size()));
printf("guaranteed_vector_lengths: ");
print(GuaranteedVectorLengths{});
printf(",\nguaranteed_vector_strides: ");

View File

@@ -91,7 +91,7 @@ struct Default2DEpilogue
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* = nullptr)
void* = nullptr) const
{
const auto storeOrUpdateTile = [&](const auto& o_tile) {
// TODO: this is ugly

View File

@@ -103,27 +103,41 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
const auto do_lds_ptr0 = reinterpret_cast<OGradDataType*>(smem_ptr_);
const auto do_lds_ptr1 = reinterpret_cast<OGradDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>());
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>(
const auto lse_lds_ptr0 = reinterpret_cast<LSEDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
const auto lse_lds_ptr1 = reinterpret_cast<LSEDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>());
const auto d_lds_ptr0 = reinterpret_cast<DDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeLSE<Problem>());
const auto d_lds_ptr1 = reinterpret_cast<DDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
const auto ds_lds_ptr = reinterpret_cast<GemmDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>() +
Policy::template GetSmemSizeD<Problem>());
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
return run(k_lds_ptr,
v_lds_ptr,
@@ -131,8 +145,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
do_lds_ptr1,
q_lds_ptr0,
q_lds_ptr1,
lse_lds_ptr,
d_lds_ptr,
lse_lds_ptr0,
lse_lds_ptr1,
d_lds_ptr0,
d_lds_ptr1,
ds_lds_ptr,
bias_lds_ptr,
std::forward<Ts>(args)...);
@@ -156,8 +172,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
OGradDataType* __restrict__ do_lds_ptr1,
QDataType* __restrict__ q_lds_ptr0,
QDataType* __restrict__ q_lds_ptr1,
LSEDataType* __restrict__ lse_lds_ptr,
DDataType* __restrict__ d_lds_ptr,
LSEDataType* __restrict__ lse_lds_ptr0,
LSEDataType* __restrict__ lse_lds_ptr1,
DDataType* __restrict__ d_lds_ptr0,
DDataType* __restrict__ d_lds_ptr1,
GemmDataType* __restrict__ ds_lds_ptr,
BiasDataType* __restrict__ bias_lds_ptr,
const QDramBlockWindowTmp& q_dram_block_window_tmp,
@@ -389,38 +407,38 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
"BiasDataType and BiasGradDataType should be the same!");
// LSE: HBM -> LDS ->Reg
auto lse_dram_window = make_tile_window(
lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto lse_dram_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto lse_lds = make_tensor_view<address_space_enum::lds>(
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
lse_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
auto lse_lds_read_window = make_tile_window(
lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
auto lse_lds_read_window =
make_tile_window(lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
// D: HBM ->Reg
auto d_dram_window = make_tile_window(
d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_dram_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto d_lds = make_tensor_view<address_space_enum::lds>(
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
d_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
auto d_lds_read_window = make_tile_window(
d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
auto d_lds_read_window =
make_tile_window(d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
// RandVal: HBM ->Reg
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
@@ -471,27 +489,31 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
decltype(gemm_2.MakeCBlockTile()) dp_acc, ds;
decltype(gemm_4.MakeCBlockTile()) dq_acc;
decltype(load_tile(lse_dram_window)) lse_block_tile;
decltype(load_tile(d_dram_window)) d_block_tile;
index_t i_total_bodys = 0;
auto main_body_impl = [&](auto is_prologue_,
auto is_epilogue_,
QDataType* const __restrict__ q_lds_ptr_curr,
QDataType* const __restrict__ q_lds_ptr_next,
OGradDataType* const __restrict__ do_lds_ptr_curr,
OGradDataType* const __restrict__ do_lds_ptr_next) mutable {
OGradDataType* const __restrict__ do_lds_ptr_next,
LSEDataType* const __restrict__ lse_lds_ptr_curr,
LSEDataType* const __restrict__ lse_lds_ptr_next,
DDataType* const __restrict__ d_lds_ptr_curr,
DDataType* const __restrict__ d_lds_ptr_next
) mutable {
constexpr bool is_prologue = is_prologue_.value;
constexpr bool is_epilogue = is_epilogue_.value;
static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
constexpr bool is_main_body = is_prologue && is_epilogue;
if constexpr(is_prologue)
{
lse_block_tile = load_tile(lse_dram_window);
lse_lds_write_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_next);
async_load_tile(lse_lds_write_window, lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
d_block_tile = load_tile(d_dram_window);
d_lds_write_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_next);
async_load_tile(d_lds_write_window, d_dram_window);
move_tile_window(d_dram_window, {kM0});
q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
@@ -510,6 +532,13 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr);
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
}
if constexpr(is_epilogue)
{
lse_lds_read_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_curr);
lse = load_tile(lse_lds_read_window);
d_lds_read_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_curr);
d = load_tile(d_lds_read_window);
}
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm0();
__builtin_amdgcn_sched_barrier(0);
@@ -617,11 +646,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm12();
__builtin_amdgcn_sched_barrier(0);
if constexpr(is_prologue)
{
store_tile(lse_lds_write_window, lse_block_tile);
store_tile(d_lds_write_window, d_block_tile);
}
if constexpr(is_epilogue)
{
// STAGE 5, P^T(PGrad^T - D)
@@ -676,13 +700,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
store_tile(ds_lds_window, ds_gemm);
}
__builtin_amdgcn_s_waitcnt(3952);
s_waitcnt</*vmcnt=*/0>();
block_sync_lds();
if constexpr(is_prologue)
{
q_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
q_reg_tensor = load_tile(q_lds_read_window);
lse = load_tile(lse_lds_read_window);
}
if constexpr(is_epilogue)
{
@@ -720,7 +743,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
{
do_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
do_reg_tensor = load_tile(do_lds_read_window);
d = load_tile(d_lds_read_window);
}
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm4();
@@ -749,17 +771,25 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
};
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
const bool is_even = (i_total_bodys % 2 == 0);
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
const bool is_even = (i_total_bodys % 2 == 0);
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
const auto lse_lds_ptr_curr = is_even ? lse_lds_ptr1 : lse_lds_ptr0;
const auto lse_lds_ptr_next = is_even ? lse_lds_ptr0 : lse_lds_ptr1;
const auto d_lds_ptr_curr = is_even ? d_lds_ptr1 : d_lds_ptr0;
const auto d_lds_ptr_next = is_even ? d_lds_ptr0 : d_lds_ptr1;
main_body_impl(is_prologue_,
is_epilogue_,
q_lds_ptr_curr,
q_lds_ptr_next,
do_lds_ptr_curr,
do_lds_ptr_next);
do_lds_ptr_next,
lse_lds_ptr_curr,
lse_lds_ptr_next,
d_lds_ptr_curr,
d_lds_ptr_next);
i_total_bodys += 1;
};

View File

@@ -363,38 +363,38 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
"BiasDataType and BiasGradDataType should be the same!");
// LSE: HBM -> LDS ->Reg
auto lse_dram_window = make_tile_window(
lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto lse_dram_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto lse_lds = make_tensor_view<address_space_enum::lds>(
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
auto lse_lds_read_window = make_tile_window(
lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
auto lse_lds_read_window =
make_tile_window(lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
// D: HBM ->Reg
auto d_dram_window = make_tile_window(
d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_dram_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto d_lds = make_tensor_view<address_space_enum::lds>(
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
auto d_lds_read_window = make_tile_window(
d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
auto d_lds_read_window =
make_tile_window(d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
// RandVal: HBM ->Reg
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), true>(
@@ -707,18 +707,18 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
dk_epilogue(dk_dram_window, dk_acc);
dk_epilogue(dk_dram_window, dk_acc, nullptr);
move_tile_window(dk_dram_window, {kN0, 0});
dv_epilogue(dv_dram_window, dv_acc);
dv_epilogue(dv_dram_window, dv_acc, nullptr);
move_tile_window(dv_dram_window, {kN0, 0});
}
};
for(index_t i = 0; i < seqlen_kv_start; i += kN0)
{
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0});
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0}, nullptr);
move_tile_window(dk_dram_window, {kN0, 0});
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0});
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0}, nullptr);
move_tile_window(dv_dram_window, {kN0, 0});
}
@@ -740,9 +740,9 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
const auto seqlen_kv_length = k_length.at(number<0>{});
for(; seqlen_kv_step < seqlen_kv_length; seqlen_kv_step += kN0)
{
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0});
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0}, nullptr);
move_tile_window(dk_dram_window, {kN0, 0});
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0});
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0}, nullptr);
move_tile_window(dv_dram_window, {kN0, 0});
}
@@ -752,8 +752,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
dq_acc);
else
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
// static_assert(kIsDeterministic);
dq_epilogue(dq_dram_window, dq_acc);
dq_epilogue(dq_dram_window, dq_acc, nullptr);
return;
}
};

View File

@@ -194,13 +194,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
return total_pixels / GetAlignmentOGrad<Problem>();
return GetTransposedAlignmentX<typename Problem::OGradDataType>();
}
template <typename Problem>
@@ -358,11 +352,30 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
Problem::BlockFmhaShape::kVHeaddim>();
}
template <typename Problem, typename BlockGemm>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution()
{
return BlockFmhaBwdPipelineDefaultPolicy::MakeLSEDDramTileDistribution<Problem,
BlockGemm>();
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t N0 = MWarp * NWarp;
constexpr index_t M1 = kMPerBlock;
constexpr index_t M0 = get_warp_size() / M1;
static_assert(M1 <= get_warp_size() && get_warp_size() % M1 == 0,
"M1 must be a factor of warp size");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<N0, M0>,
tuple<sequence<M1, 1>>,
tuple<sequence<0>, sequence<0, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1>,
sequence<1>>{});
}
template <typename Problem>
@@ -793,9 +806,10 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
return lsed_lds_block_desc;
}
template <typename Problem, typename BlockGemm>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsReadBlockDescriptor()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
@@ -984,15 +998,16 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeLSE()
{
return sizeof(typename Problem::LSEDataType) *
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
return static_cast<index_t>(max( //
sizeof(int) * get_warp_size(),
sizeof(typename Problem::LSEDataType) *
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size()));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeD()
{
return sizeof(typename Problem::DDataType) *
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
return GetSmemSizeLSE<Problem>();
}
template <typename Problem>
@@ -1039,8 +1054,9 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();
constexpr index_t smem_size_stage0 = smem_size_k + smem_size_v;
constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 + smem_size_lse +
smem_size_d + max(smem_size_bias, smem_size_ds);
constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 +
smem_size_lse * 2 + smem_size_d * 2 +
max(smem_size_bias, smem_size_ds);
return max(smem_size_stage0, smem_size_stage1);
}
@@ -1090,6 +1106,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
static constexpr index_t LSE_VMEM_READ = 1;
static constexpr index_t D_VMEM_READ = 1;
static constexpr index_t DQ_VMEM_WRITE = kM0 * kQKHeaddim / kBlockSize; // atomic add
// LDS Read
static constexpr index_t OGradT_LDS_READ =
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
@@ -1116,11 +1134,12 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t OGradT_LDS_WRITE =
kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
static constexpr index_t LSE_LDS_WRITE = 1;
static constexpr index_t D_LDS_WRITE = 1;
static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
public:
static constexpr index_t TOTAL_VMEM_READ =
Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ + DQ_VMEM_WRITE;
CK_TILE_DEVICE static constexpr void SchedulerGemm0()
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
@@ -1128,7 +1147,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
constexpr index_t VMEM_READ_INST =
Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
constexpr index_t MFMA_INST = Gemm0MFMA;
constexpr index_t LDS_READ_INST = OGradT_LDS_READ;
constexpr index_t LDS_READ_INST = OGradT_LDS_READ + LSE_LDS_READ + D_LDS_READ;
constexpr index_t lcm_inst = lcm(VMEM_READ_INST, MFMA_INST, LDS_READ_INST);
static_for<0, lcm_inst, 1>{}([&](auto i) {
@@ -1161,8 +1180,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
{
// Mem: LSE/D LDS store, SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
constexpr index_t LDS_WRITE_INST = LSE_LDS_WRITE + D_LDS_WRITE + SGradT_LDS_WRITE;
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ;
constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE;
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ;
constexpr index_t MFMA_INST = Gemm3MFMA;
constexpr index_t lds_rw_inst = LDS_WRITE_INST + LDS_READ_INST;
@@ -1185,7 +1204,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
{
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ;
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ;
constexpr index_t MFMA_INST = Gemm4MFMA;
constexpr index_t lcm_inst = lcm(MFMA_INST, LDS_READ_INST);

View File

@@ -33,15 +33,14 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
@@ -74,9 +73,6 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);

View File

@@ -266,6 +266,10 @@ struct GroupedGemmKernel
const tuple<index_t, index_t>& block_idx_2d,
const index_t block_idx_z) const
{
static_assert(GemmPipeline::DoubleSmemBuffer || !GemmPipeline::Preshuffle,
"SingleSmemBuffer and Preshuffle cannot both be enabled simultaneously!");
const auto [iM, iN] = block_idx_2d;
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
@@ -282,11 +286,15 @@ struct GroupedGemmKernel
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
// TO DO:
// Can we simplify this branching logic?
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(UsePersistentKernel)
if constexpr(UsePersistentKernel || GemmPipeline::Preshuffle)
{
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
c_ptr,
@@ -296,9 +304,11 @@ struct GroupedGemmKernel
splitk_batch_offset,
i_m,
i_n);
return;
}
else
{
Base::RunGemm2LDS({a_ptr},
{b_ptr},
{/*ds_ptr*/},
@@ -311,14 +321,14 @@ struct GroupedGemmKernel
i_n);
}
}
else
else // SingleSmemBuffer
{
if constexpr(UsePersistentKernel)
{
RunGemmWithPipelineSelection(
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
else
else // Non-persistent kernel
{
Base::RunGemm({a_ptr},
{b_ptr},
@@ -438,17 +448,34 @@ struct GroupedGemmKernel
// Get hot-loop and tail configuration
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
// Run GEMM pipeline with compile-time branching
const auto& c_block_tile = [&]() {
if constexpr(GemmPipeline::Preshuffle)
{
// Preshuffle version - without has_hot_loop parameter
return GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
}
else
{
// Regular version - with has_hot_loop parameter
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
return GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
}
}();
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}.template
@@ -491,8 +518,9 @@ struct GroupedGemmKernel
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
cast_pointer_to_generic_address_space(gemm_descs_const));
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
const auto& kargs = gemm_desc_ptr[group_id];
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
const auto& kargs = gemm_desc_ptr[group_id];
const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
0,

View File

@@ -43,7 +43,7 @@ template <bool kPadM_,
bool UseStructuredSparsity_ = false,
bool UsePersistentKernel_ = false,
index_t NumWaveGroups_ = 1,
bool Preshuffle_ = 0>
bool Preshuffle_ = false>
struct TileGemmUniversalTraits
{
static constexpr bool kPadM = kPadM_;

View File

@@ -296,6 +296,73 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
WarpGemm>;
return BlockWeightPreshuffleASmemBSmemCRegV1<Problem, BlockWeightPreshufflePolicy>{};
}
/**
* @brief Get the vector store size for C tensor.
*
* @tparam Problem - Gemm pipeline problem class.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
using BlockGemm = remove_cvref_t<decltype(GetBlockWeightPreshuffle<Problem>())>;
using WG_ = typename BlockGemm::WG;
constexpr bool TransposeC = Problem::TransposeC;
using CLayout = typename Problem::CLayout;
using CWarpDstr = typename WG_::CWarpDstr;
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
else
{
// In this case each thread has just a single item in Ndim
return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
}
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread has just a single item in Mdim
return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
}
else
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
};
} // namespace ck_tile

View File

@@ -6,6 +6,7 @@
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/grouped_gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"

View File

@@ -769,12 +769,11 @@ struct QuantGemmKernel
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
const auto& b_pad_view = views.at(I2);
const auto& bq_pad_view = views.at(I3);
const auto& c_pad_view = views.at(I4);
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
const auto& b_pad_view = views.at(I2);
const auto& bq_pad_view = views.at(I3);
const auto& c_pad_view = views.at(I4);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{

View File

@@ -0,0 +1,433 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/host.hpp"
#include <hip/hip_runtime.h>
namespace ck_tile {
/// @brief The Grouped GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref GroupedGemmKernel "GroupedGemmKernel" when creating kernel
/// arguments object. It contain all necessary information required to build proper kernel
/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by
/// stating all required information like M,N,K sizes and respective strides.
struct QuantGroupedGemmHostArgs
{
CK_TILE_HOST QuantGroupedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* e_ptr_,
const void* aq_ptr_,
const void* bq_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t QK_A_,
index_t QK_B_,
index_t stride_A_,
index_t stride_B_,
index_t stride_E_,
index_t stride_AQ_,
index_t stride_BQ_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
aq_ptr(aq_ptr_),
bq_ptr(bq_ptr_),
e_ptr(e_ptr_),
M(M_),
N(N_),
K(K_),
QK_A(QK_A_),
QK_B(QK_B_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_AQ(stride_AQ_),
stride_BQ(stride_BQ_),
stride_E(stride_E_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
const void* aq_ptr;
const void* bq_ptr;
union
{
void* e_ptr;
void* c_ptr;
};
index_t M;
index_t N;
index_t K;
index_t QK_A;
index_t QK_B;
index_t stride_A;
index_t stride_B;
index_t stride_AQ;
index_t stride_BQ;
union
{
index_t stride_E;
index_t stride_C;
};
index_t k_batch;
};
using QuantGroupedGemmKernelArgs = QuantGemmKernelArgs;
struct QuantGemmTransKernelArg
{
QuantGroupedGemmKernelArgs group_karg;
ck_tile::index_t block_start;
ck_tile::index_t block_end;
QuantGemmTransKernelArg() = delete;
QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
{
}
QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs&& karg)
: group_karg{karg}, block_start{0}, block_end{0}
{
}
};
template <typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_,
QuantType QuantType_>
struct QuantGroupedGemmKernel
{
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
/// functions.
using Base = QuantGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
//// @brief Specify the layout configurations for A, B, C/E
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
/// @brief Specify the data type configurations for A, B, C/E
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
using AQDataType =
remove_cvref_t<typename detail::get_aq_data_type_or<GemmPipeline, AccDataType>::type>;
using BQDataType =
remove_cvref_t<typename detail::get_bq_data_type_or<GemmPipeline, AccDataType>::type>;
static constexpr auto kQuantType = QuantType_;
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
static_assert(
!is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
"ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
static_assert(
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, CLayout>::value &&
!is_detected<is_tuple, CDataType>::value,
"C/ELayout and C/EDataType must be scalars.");
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
using Kernel =
QuantGroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline, kQuantType>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
static_assert(UsePersistentKernel == true, "UsePersistentKernel must be true");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
using P_ = GemmPipeline;
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
(UsePersistentKernel ? "Persistent" : "NonPersistent"));
// clang-format on
}
CK_TILE_HOST static auto
GetWorkSpaceSize(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs) -> std::size_t
{
return gemm_descs.size() * sizeof(QuantGemmTransKernelArg);
}
CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
{
return group_count * sizeof(QuantGemmTransKernelArg);
}
CK_TILE_HOST static auto BlockSize() -> dim3
{
if(is_wave32())
{
return dim3(kBlockSize / 2);
}
else
{
return dim3(kBlockSize);
}
}
/**
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
* @return The maximum occupancy grid size.
* @note This function queries the maximum occupancy of the kernel using
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
*/
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
const auto kernel_func = kentry<1, Kernel, ConstantPointer, index_t>;
int occupancy;
HIP_CHECK_ERROR(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel_func, kBlockSize, 0));
const int grid_size = get_available_compute_units(s) * occupancy;
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static auto GridSize(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs)
{
index_t grid_size = 0;
for(const auto& it_desc : gemm_descs)
{
const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
grid_size += local_grid_size * it_desc.k_batch;
}
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static auto MakeKargs(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs)
-> std::vector<QuantGemmTransKernelArg>
{
std::vector<QuantGemmTransKernelArg> gemm_kernel_args_;
index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
index_t grid_size = 0;
gemm_kernel_args_.reserve(group_count);
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
{
const index_t M = gemm_descs[i].M;
const index_t N = gemm_descs[i].N;
const index_t K = gemm_descs[i].K;
if(M == 0 || N == 0 || K == 0)
{
continue;
}
const index_t stride_a = gemm_descs[i].stride_A;
const index_t stride_b = gemm_descs[i].stride_B;
const index_t stride_e = gemm_descs[i].stride_C;
const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;
const index_t block_start = grid_size;
const index_t block_end = grid_size + grid_size_grp;
grid_size += grid_size_grp;
auto karg =
QuantGroupedGemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
type_convert<CDataType*>(gemm_descs[i].e_ptr),
type_convert<const AQDataType*>(gemm_descs[i].aq_ptr),
type_convert<const BQDataType*>(gemm_descs[i].bq_ptr),
gemm_descs[i].k_batch,
M,
N,
K,
gemm_descs[i].QK_A,
gemm_descs[i].QK_B,
stride_a,
stride_b,
stride_e,
gemm_descs[i].stride_AQ,
gemm_descs[i].stride_BQ};
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
}
return gemm_kernel_args_;
}
CK_TILE_HOST static bool IsSupportedArgument(const std::vector<QuantGemmTransKernelArg>& kargs)
{
for(const auto& karg : kargs)
{
if(!Base::IsSupportedArgument(karg.group_karg))
{
return false;
}
}
return true;
}
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
{
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void Run(const QuantGroupedGemmKernelArgs& kargs,
const tuple<index_t, index_t>& block_idx_2d,
const index_t block_idx_z) const
{
const auto [iM, iN] = block_idx_2d;
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);
// options
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
static_assert(GemmPipeline::DoubleSmemBuffer == false,
"DoubleSmemBuffer needs to be false");
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
RunGemmWithPipelineSelection(
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @note The GEMM pipeline is selected in-kernel based on the number of K-loops
* and the tail-number. This is needed for the persistent tile-loop when
* we didn't have access to the K dimension on the host.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param aq_ptr input AQ pointer
* @param bq_ptr input BQ pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k
* batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
CK_TILE_DEVICE static void
RunGemmWithPipelineSelection(const ADataType* a_ptr,
const BDataType* b_ptr,
const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const QuantGroupedGemmKernelArgs& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
const auto& b_block_window = gemm_tile_windows.at(Base::I2);
// Get hot-loop and tail configuration
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I4);
if constexpr(kQuantType == QuantType::RowColQuant)
{
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
c_block_window,
c_block_tile,
c_block_window,
smem_ptr_0,
aq_block_window,
bq_block_window);
}
}
// For persistent kernels
template <bool U = UsePersistentKernel,
typename = std::enable_if_t<U>,
typename = void> // extra template parameter to avoid redefinition
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count) const
{
const index_t grid_size = ck_tile::get_grid_size();
const auto gemm_desc_ptr = reinterpret_cast<const QuantGemmTransKernelArg*>(
cast_pointer_to_generic_address_space(gemm_descs_const));
index_t block_id = ck_tile::get_block_1d_id(); // initial block_id
index_t cum_grid_size = 0;
for(index_t group_id = 0; group_id < group_count; ++group_id)
{
const auto& kargs = gemm_desc_ptr[group_id].group_karg;
const auto& k_batch = kargs.k_batch;
const auto block_start = cum_grid_size;
cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
while(block_id < cum_grid_size)
{
const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
block_id = block_id + grid_size; // advance to next block
// NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
if(block_id >= cum_grid_size)
{
break; // exit the loop if all blocks are processed
}
}
}
}
};
} // namespace ck_tile

View File

@@ -23,8 +23,10 @@ template <bool kPadM_,
typename BLayout_,
typename CLayout_,
QuantType QuantType_,
typename AQLayout_ = ALayout_,
typename BQLayout_ = BLayout_>
typename AQLayout_ = ALayout_,
typename BQLayout_ = BLayout_,
bool DoubleSmemBuffer_ = false,
bool UsePersistentKernel_ = false>
struct TileGemmQuantTraits
{
static constexpr bool kPadM = kPadM_;
@@ -33,7 +35,8 @@ struct TileGemmQuantTraits
static constexpr QuantType kQuantType = QuantType_;
static constexpr int _VectorSize = 16;
static constexpr int _VectorSize = 16;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
using ALayout = ALayout_;
using BLayout = BLayout_;
@@ -44,6 +47,7 @@ struct TileGemmQuantTraits
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
};