mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Merge branch 'develop' into lwpck-4181
This commit is contained in:
@@ -72,7 +72,12 @@ inline bool is_xdl_supported()
|
||||
is_gfx12_supported() || is_gfx11_supported();
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, index_t MPerXDL, index_t NPerXDL>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
index_t MPerXDL64,
|
||||
index_t NPerXDL64,
|
||||
index_t MPerXDL32 = MPerXDL64,
|
||||
index_t NPerXDL32 = NPerXDL64>
|
||||
inline bool is_xdl_wmma_supported()
|
||||
{
|
||||
if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
@@ -82,7 +87,7 @@ inline bool is_xdl_wmma_supported()
|
||||
}
|
||||
else if(is_gfx12_supported() || is_gfx11_supported())
|
||||
{
|
||||
if constexpr((MPerXDL != 16) || (NPerXDL != 16))
|
||||
if constexpr((MPerXDL32 != 16) || (NPerXDL32 != 16))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -8,10 +8,16 @@
|
||||
#include <sstream>
|
||||
#include <regex>
|
||||
#include <optional>
|
||||
#include <memory>
|
||||
|
||||
#include "ck/stream_config.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/description.hpp"
|
||||
#endif
|
||||
#endif
|
||||
#include "ck/utility/get_id.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -91,6 +97,57 @@ static constexpr auto GetNXdlPerWave2()
|
||||
IsWave64>(); \
|
||||
}
|
||||
|
||||
template <index_t BlockSize_,
|
||||
index_t MPerBlock_,
|
||||
index_t NPerBlock_,
|
||||
index_t MPerXDL_,
|
||||
index_t NPerXDL_,
|
||||
index_t MXdlPerWave_,
|
||||
index_t CShuffleMXdlPerWavePerShuffle_,
|
||||
index_t CShuffleNXdlPerWavePerShuffle_,
|
||||
bool IsWave64>
|
||||
static constexpr auto GetWarpTileConfig()
|
||||
{
|
||||
constexpr auto MXdlPerWave64 = MXdlPerWave_;
|
||||
constexpr auto MXdlPerWave32 = MXdlPerWave_ * MPerXDL_ / 16;
|
||||
constexpr auto CShuffleMXdlPerWavePerShuffle32 = CShuffleMXdlPerWavePerShuffle_ * MPerXDL_ / 16;
|
||||
|
||||
constexpr auto NXdlPerWave =
|
||||
IsWave64
|
||||
? GetNXdlPerWave2<BlockSize_,
|
||||
MPerBlock_,
|
||||
NPerBlock_,
|
||||
MPerXDL_,
|
||||
NPerXDL_,
|
||||
MXdlPerWave_,
|
||||
true>()
|
||||
: GetNXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
|
||||
|
||||
if constexpr(IsWave64 == false && NXdlPerWave != 0)
|
||||
{
|
||||
constexpr auto CShuffleNXdlPerWavePerShuffle32 =
|
||||
NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
|
||||
? CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
|
||||
: CShuffleNXdlPerWavePerShuffle_;
|
||||
static_assert(CShuffleNXdlPerWavePerShuffle32 > 0);
|
||||
return Sequence<16,
|
||||
16,
|
||||
MXdlPerWave32,
|
||||
NXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle32,
|
||||
CShuffleNXdlPerWavePerShuffle32>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<MPerXDL_,
|
||||
NPerXDL_,
|
||||
MXdlPerWave64,
|
||||
NXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle_,
|
||||
CShuffleNXdlPerWavePerShuffle_>{};
|
||||
}
|
||||
}
|
||||
|
||||
#define INVOKER_RUN_IMPL \
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
|
||||
{ \
|
||||
@@ -227,6 +284,12 @@ struct BaseOperator
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
|
||||
virtual std::string GetTypeString() const { return ""; }
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
// Return a description object for this operator, or nullptr if not supported.
|
||||
virtual std::unique_ptr<ck_tile::reflect::Description> describe() const { return nullptr; }
|
||||
#endif
|
||||
|
||||
virtual std::string GetInstanceString() const { return ""; }
|
||||
|
||||
virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
|
||||
|
||||
@@ -166,11 +166,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
{
|
||||
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
|
||||
|
||||
GET_NXDL_PER_WAVE_IMPL
|
||||
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
|
||||
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr auto WarpTileConfig64 = GetWarpTileConfig<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
true>();
|
||||
static constexpr auto WarpTileConfig32 = GetWarpTileConfig<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
false>();
|
||||
static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3);
|
||||
static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3);
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -321,7 +337,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
|
||||
|
||||
// GridwiseGemm
|
||||
template <index_t NXdlPerWave_>
|
||||
template <typename WarpTileConfig>
|
||||
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
@@ -340,10 +356,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave_,
|
||||
WarpTileConfig::At(0),
|
||||
WarpTileConfig::At(1),
|
||||
WarpTileConfig::At(2),
|
||||
WarpTileConfig::At(3),
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -360,13 +376,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
WarpTileConfig::At(4),
|
||||
WarpTileConfig::At(5),
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<decltype(WarpTileConfig64)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<decltype(WarpTileConfig32)>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
@@ -588,7 +604,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_wmma_supported<ComputeDataType, ComputeDataType, MPerXDL, NPerXDL>())
|
||||
if(!ck::is_xdl_wmma_supported<ComputeDataType,
|
||||
ComputeDataType,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
WarpTileConfig32.At(0),
|
||||
WarpTileConfig32.At(1)>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -783,6 +804,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< ABlockTransferSrcVectorDim << ", "
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/description.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
#endif
|
||||
|
||||
@@ -1240,6 +1241,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
"for the given template parameters.");
|
||||
return ck_tile::reflect::instance_string<DeviceOp>();
|
||||
}
|
||||
|
||||
std::unique_ptr<ck_tile::reflect::Description> describe() const override
|
||||
{
|
||||
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
|
||||
}
|
||||
#endif
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/description.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#endif
|
||||
|
||||
@@ -1064,6 +1065,11 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
"for the given template parameters.");
|
||||
return ck_tile::reflect::instance_string<DeviceOp>();
|
||||
}
|
||||
|
||||
std::unique_ptr<ck_tile::reflect::Description> describe() const override
|
||||
{
|
||||
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/conv_describe.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
#endif
|
||||
|
||||
@@ -2080,6 +2081,21 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
"for the given template parameters.");
|
||||
return ck_tile::reflect::instance_string<DeviceOp>();
|
||||
}
|
||||
|
||||
std::unique_ptr<ck_tile::reflect::Description> describe() const override
|
||||
{
|
||||
static_assert(ck_tile::reflect::conv::HasConvTraits<DeviceOp>,
|
||||
"ConvTraits specialization not found for this device operation. "
|
||||
"If you modified the template parameters of this class, ensure that "
|
||||
"the corresponding ConvTraits specialization in "
|
||||
"ck_tile/builder/reflect/conv_traits.hpp is updated to match, or that "
|
||||
"InstanceTraits in "
|
||||
"ck_tile/builder/reflect/"
|
||||
"instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp "
|
||||
"provides all required members for ConvTraits to work.");
|
||||
return std::make_unique<ck_tile::reflect::conv::ConvDescription>(
|
||||
ck_tile::reflect::describe<DeviceOp>());
|
||||
}
|
||||
#endif
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/conv_describe.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
#endif
|
||||
|
||||
@@ -2103,6 +2104,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
"for the given template parameters.");
|
||||
return ck_tile::reflect::instance_string<DeviceOp>();
|
||||
}
|
||||
|
||||
std::unique_ptr<ck_tile::reflect::Description> describe() const override
|
||||
{
|
||||
return std::make_unique<ck_tile::reflect::conv::ConvDescription>(
|
||||
ck_tile::reflect::describe<DeviceOp>());
|
||||
}
|
||||
#endif
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/description.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
|
||||
#endif
|
||||
|
||||
@@ -1019,6 +1020,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
"for the given template parameters.");
|
||||
return ck_tile::reflect::instance_string<DeviceOp>();
|
||||
}
|
||||
|
||||
std::unique_ptr<ck_tile::reflect::Description> describe() const override
|
||||
{
|
||||
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/conv_describe.hpp"
|
||||
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
|
||||
#endif
|
||||
|
||||
@@ -1238,6 +1239,22 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
"for the given template parameters.");
|
||||
return ck_tile::reflect::instance_string<DeviceOp>();
|
||||
}
|
||||
|
||||
std::unique_ptr<ck_tile::reflect::Description> describe() const override
|
||||
{
|
||||
static_assert(
|
||||
ck_tile::reflect::conv::HasConvTraits<DeviceOp>,
|
||||
"ConvTraits specialization not found for this device operation. "
|
||||
"If you modified the template parameters of this class, ensure that "
|
||||
"the corresponding ConvTraits specialization in "
|
||||
"ck_tile/builder/reflect/conv_traits.hpp is updated to match, or that "
|
||||
"InstanceTraits in "
|
||||
"ck_tile/builder/reflect/"
|
||||
"instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp "
|
||||
"provides all required members for ConvTraits to work.");
|
||||
return std::make_unique<ck_tile::reflect::conv::ConvDescription>(
|
||||
ck_tile::reflect::describe<DeviceOp>());
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
@@ -620,7 +620,44 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
bool isWave64 = get_warp_size() == 64;
|
||||
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
const auto& a = arg.gemm_kernel_args_[i].karg_;
|
||||
const auto& a = arg.gemm_kernel_args_[i].karg_;
|
||||
|
||||
// Validate stride requirements for SplitK (k_batch > 1)
|
||||
// TODO: Enable splitK
|
||||
if(a.k_batch > 1)
|
||||
{
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(a.StrideC != a.N)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " SplitK (k_batch=" << a.k_batch
|
||||
<< ") requires contiguous output stride."
|
||||
<< " For RowMajor layout: StrideC must equal N."
|
||||
<< " Got StrideC=" << a.StrideC << ", N=" << a.N << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
if(a.StrideC != a.M)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " SplitK (k_batch=" << a.k_batch
|
||||
<< ") requires contiguous output stride."
|
||||
<< " For ColumnMajor layout: StrideC must equal M."
|
||||
<< " Got StrideC=" << a.StrideC << ", M=" << a.M << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool group_arg_valid = false;
|
||||
if(isWave64)
|
||||
{
|
||||
|
||||
@@ -527,11 +527,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
|
||||
}
|
||||
else
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
// TODO: remove this restriction
|
||||
static_assert(ScaleBlockM >= MPerWmma,
|
||||
"ScaleBlockM must be greater equal than MPerWmma");
|
||||
#endif
|
||||
static_assert(
|
||||
ScaleBlockK >=
|
||||
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::
|
||||
|
||||
@@ -366,6 +366,26 @@ struct amdgcn_compiler_target_state
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1010 = false;
|
||||
#endif
|
||||
#if defined(__gfx1011__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1011 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1011 = false;
|
||||
#endif
|
||||
#if defined(__gfx1012__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1012 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1012 = false;
|
||||
#endif
|
||||
#if defined(__gfx1013__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1013 = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX1013 = false;
|
||||
#endif
|
||||
#if defined(__gfx10_1_generic__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = true;
|
||||
#else
|
||||
static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = false;
|
||||
#endif // __gfx10_1_generic__
|
||||
|
||||
#if defined(__gfx1030__)
|
||||
static constexpr bool CK_TILE_ARCH_GFX1030 = true;
|
||||
@@ -504,6 +524,10 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1011, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1012, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1013, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX10_1_GENERIC, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \
|
||||
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \
|
||||
|
||||
@@ -20,7 +20,7 @@ CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, cons
|
||||
{
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
if(mask.IsOutOfBound(m, n))
|
||||
if(mask.IsOutOfSinkBound(m, n))
|
||||
c_b_m_n(batch, m, n) = -ck_tile::numeric<CDataType>::infinity();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,77 +34,80 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto f_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0, v_block_acc = 0;
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
|
||||
std::is_same_v<ADataType, bf8_t>);
|
||||
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
|
||||
std::is_same_v<BDataType, pk_int4_t>);
|
||||
static_assert(std::is_same_v<AccDataType, float>);
|
||||
static_assert(std::is_same_v<CDataType, float> ||
|
||||
std::is_same_v<CDataType, ck_tile::half_t>);
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
constexpr std::size_t kGroupK = QuantGroupSize::kK;
|
||||
|
||||
// ---- A loader: dequant A(m,k) into AccDataType ----
|
||||
auto load_a = [&](std::size_t k) -> AccDataType {
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
return (k & 1) ? fp32_val.hi : fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
return ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
}
|
||||
};
|
||||
|
||||
// ---- B loader: dequant B(k,n) into AccDataType ----
|
||||
auto load_b = [&](std::size_t k) -> AccDataType {
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
return (k & 1) ? fp32_val.hi : fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
|
||||
return fp8_to_float_raw(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
return ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
v_block_acc += v_a * v_b;
|
||||
};
|
||||
|
||||
// Apply group dequant scale
|
||||
if((k + 1) % QuantGroupSize::kK == 0)
|
||||
// ---- scale loader for a given K-group index ----
|
||||
auto load_scale = [&](ck_tile::index_t k_group) -> float {
|
||||
const ck_tile::index_t outer_dim = aquant ? (m / QuantGroupSize::kM) : k_group;
|
||||
const ck_tile::index_t inner_dim = aquant ? k_group : (n / QuantGroupSize::kN);
|
||||
|
||||
if constexpr(std::is_same_v<QDataType, float>)
|
||||
{
|
||||
float scale = 0.f;
|
||||
index_t outer_dim = (aquant) ? (m / QuantGroupSize::kM) : (k / QuantGroupSize::kK);
|
||||
index_t inner_dim = (aquant) ? (k / QuantGroupSize::kK) : (n / QuantGroupSize::kN);
|
||||
if constexpr(std::is_same_v<QDataType, float>)
|
||||
{
|
||||
scale = q(outer_dim, inner_dim);
|
||||
}
|
||||
else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
scale = fp8_to_float_raw(q(outer_dim, inner_dim));
|
||||
}
|
||||
else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
scale = bf8_to_float_raw(q(outer_dim, inner_dim));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unexpected Q datatype.");
|
||||
}
|
||||
v_block_acc *= scale;
|
||||
v_acc += v_block_acc;
|
||||
v_block_acc = 0;
|
||||
return q(outer_dim, inner_dim);
|
||||
}
|
||||
else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
return fp8_to_float_raw(q(outer_dim, inner_dim));
|
||||
}
|
||||
else // QDataType == bf8_t by static_assert above
|
||||
{
|
||||
return bf8_to_float_raw(q(outer_dim, inner_dim));
|
||||
}
|
||||
};
|
||||
|
||||
// ---- Loop over K by groups (full and tail) ----
|
||||
for(std::size_t k_begin = 0; k_begin < K; k_begin += kGroupK)
|
||||
{
|
||||
const std::size_t k_end = std::min<std::size_t>(k_begin + kGroupK, K);
|
||||
|
||||
AccDataType v_block_acc = 0;
|
||||
|
||||
// unscaled accumulation within this K-group
|
||||
for(std::size_t k = k_begin; k < k_end; ++k)
|
||||
{
|
||||
const AccDataType v_a = load_a(k);
|
||||
const AccDataType v_b = load_b(k);
|
||||
v_block_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
const ck_tile::index_t k_group = static_cast<ck_tile::index_t>(k_begin / kGroupK);
|
||||
const float scale = load_scale(k_group);
|
||||
|
||||
v_acc += v_block_acc * scale;
|
||||
}
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
|
||||
@@ -84,7 +84,7 @@ auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
@@ -94,10 +94,10 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
@@ -114,18 +114,24 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
gemmConfig.K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
return shuffle_b(t, GemmConfig{});
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
|
||||
{
|
||||
@@ -145,22 +151,22 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
|
||||
gemmConfig.N_Warp,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
@@ -177,17 +183,23 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
|
||||
gemmConfig.N_Warp,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
gemmConfig.K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
return shuffle_b_permuteN(t, GemmConfig{});
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -8,7 +8,8 @@
|
||||
namespace ck_tile {
|
||||
enum StreamKReductionStrategy : uint32_t
|
||||
{
|
||||
Atomic = 0u,
|
||||
Reduction = 1u
|
||||
Atomic = 0u,
|
||||
Reduction = 1u,
|
||||
TreeReduction = 2u
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -35,7 +35,8 @@ template <typename AsDataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1,
|
||||
bool TiledMMAPermuteN_ = false,
|
||||
index_t BlockedXDLN_PerWarp_ = 1> // The number of continuous xdl_output per warp
|
||||
index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp
|
||||
bool DoubleSmemBuffer_ = false>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
@@ -59,6 +60,7 @@ struct CShuffleEpilogueProblem
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
@@ -118,6 +120,7 @@ struct CShuffleEpilogue
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
@@ -204,6 +207,26 @@ struct CShuffleEpilogue
|
||||
}
|
||||
return max_vector_size / sizeof(DiDataType);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Shuffle tile configuration parameters check and aligment
|
||||
*
|
||||
* @details Return tuple(1, 1) if shuffle_tile values are too large for SMEM.
|
||||
*/
|
||||
template <index_t m_shuffle_tile, index_t n_shuffle_tile>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto AlignShuffleTileWithSmem()
|
||||
{
|
||||
constexpr index_t m_val = MPerXdl * MWave * m_shuffle_tile;
|
||||
constexpr index_t n_val = NPerXdl * NWave * n_shuffle_tile;
|
||||
|
||||
constexpr auto shuffle_tile =
|
||||
m_val * n_val * sizeof(ODataType) > get_smem_capacity() || DoubleSmemBuffer
|
||||
? std::make_tuple(1, 1)
|
||||
: std::make_tuple(m_shuffle_tile, n_shuffle_tile);
|
||||
|
||||
return shuffle_tile;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Shuffle tile configuration parameters
|
||||
*
|
||||
@@ -214,20 +237,23 @@ struct CShuffleEpilogue
|
||||
*/
|
||||
static constexpr auto shuffle_tile_tuple = [] {
|
||||
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
|
||||
if constexpr(elem_per_thread >= GetVectorSizeC())
|
||||
if constexpr(elem_per_thread <= GetVectorSizeC())
|
||||
{
|
||||
return std::make_tuple(1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
|
||||
constexpr index_t num_xdl_shuffles = elem_per_thread / GetVectorSizeC();
|
||||
static_assert(elem_per_thread % GetVectorSizeC() == 0);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
|
||||
(kMPerBlock % num_xdl_shuffles == 0),
|
||||
"kMPerBlock must be divisible by MPerXdl*MWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogue");
|
||||
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
|
||||
return AlignShuffleTileWithSmem<min(num_xdl_shuffles,
|
||||
kMPerBlock / (MPerXdl * MWave)),
|
||||
1>();
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -235,7 +261,9 @@ struct CShuffleEpilogue
|
||||
(kNPerBlock % num_xdl_shuffles == 0),
|
||||
"kNPerBlock must be divisible by NPerXdl*NWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogue");
|
||||
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
|
||||
return AlignShuffleTileWithSmem<1,
|
||||
min(num_xdl_shuffles,
|
||||
kNPerBlock / (NPerXdl * NWave))>();
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -86,21 +86,22 @@ struct GenericAttentionMask
|
||||
static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
|
||||
|
||||
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_)
|
||||
: GenericAttentionMask(0, 0, y_total_, x_total_)
|
||||
: GenericAttentionMask(0, 0, 0, y_total_, x_total_)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
|
||||
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
|
||||
GenericAttentionMask(index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
|
||||
: y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
|
||||
{
|
||||
}
|
||||
template <typename MaskCoordinates>
|
||||
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord)
|
||||
: y(mask_coord.at(number<0>{})),
|
||||
x(mask_coord.at(number<1>{})),
|
||||
y_total(mask_coord.at(number<2>{})),
|
||||
x_total(mask_coord.at(number<3>{}))
|
||||
sink(mask_coord.at(number<2>{})),
|
||||
y_total(mask_coord.at(number<3>{})),
|
||||
x_total(mask_coord.at(number<4>{}))
|
||||
{
|
||||
}
|
||||
|
||||
@@ -141,6 +142,44 @@ struct GenericAttentionMask
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, 0, x_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along X tile by tile
|
||||
index_t x_start = [&]() {
|
||||
if constexpr(IsLocal)
|
||||
{
|
||||
index_t tmp = max(-y + i_y + 1, 0);
|
||||
return (tmp / XTile) * XTile; // round to tile aligned
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t x_end = [&]() {
|
||||
index_t tmp = min(i_y + YTile - 1 + x, x_total);
|
||||
return ((tmp + XTile - 1) / XTile) * XTile;
|
||||
}();
|
||||
|
||||
index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
|
||||
if(x_start <= sink_seq_end && sink > 0)
|
||||
return ck_tile::make_tuple(0, 0, x_end);
|
||||
else
|
||||
return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
|
||||
}
|
||||
}
|
||||
|
||||
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
||||
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
||||
@@ -195,6 +234,30 @@ struct GenericAttentionMask
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
return i_x >= x_total;
|
||||
// no need to do min/max here, since i_x will never be < 0 or >= x_total
|
||||
index_t x_start = -y + i_y + 1;
|
||||
index_t x_end = min(i_y + x, x_total);
|
||||
|
||||
if constexpr(IsLocal)
|
||||
{
|
||||
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
|
||||
return false;
|
||||
else
|
||||
return i_x < x_start || i_x >= x_end;
|
||||
}
|
||||
else
|
||||
{
|
||||
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
|
||||
return false;
|
||||
else
|
||||
return i_x >= x_end || i_y >= y_total;
|
||||
}
|
||||
}
|
||||
|
||||
// if current tile is at the edge, means need per-pixel mask check.
|
||||
// otherwise no need to check per-pixel
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
||||
@@ -237,7 +300,7 @@ struct GenericAttentionMask
|
||||
}
|
||||
|
||||
private:
|
||||
index_t y, x;
|
||||
index_t y, x, sink;
|
||||
index_t y_total, x_total;
|
||||
};
|
||||
|
||||
@@ -260,21 +323,23 @@ struct SimplifiedGenericAttentionMask
|
||||
static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
|
||||
|
||||
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
|
||||
: SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
|
||||
: SimplifiedGenericAttentionMask(0, 0, 0, y_total_, x_total_)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
|
||||
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
|
||||
SimplifiedGenericAttentionMask(
|
||||
index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
|
||||
: y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
|
||||
{
|
||||
}
|
||||
template <typename MaskCoordinates>
|
||||
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
|
||||
: y(mask_coord.at(number<0>{})),
|
||||
x(mask_coord.at(number<1>{})),
|
||||
y_total(mask_coord.at(number<2>{})),
|
||||
x_total(mask_coord.at(number<3>{}))
|
||||
sink(mask_coord.at(number<2>{})),
|
||||
y_total(mask_coord.at(number<3>{})),
|
||||
x_total(mask_coord.at(number<4>{}))
|
||||
{
|
||||
}
|
||||
|
||||
@@ -308,6 +373,38 @@ struct SimplifiedGenericAttentionMask
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, 0, x_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along X tile by tile
|
||||
index_t x_start = [&]() {
|
||||
index_t tmp = max(-y + i_y + 1, 0);
|
||||
return (tmp / XTile) * XTile; // round to tile aligned
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t x_end = [&]() {
|
||||
index_t tmp = min(i_y + YTile - 1 + x, x_total);
|
||||
return ((tmp + XTile - 1) / XTile) * XTile;
|
||||
}();
|
||||
|
||||
index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
|
||||
|
||||
if(x_start <= sink_seq_end && sink > 0)
|
||||
return ck_tile::make_tuple(0, 0, x_end);
|
||||
else
|
||||
return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t TileHeight, index_t TileWidth>
|
||||
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y,
|
||||
number<TileHeight> height,
|
||||
@@ -325,6 +422,29 @@ struct SimplifiedGenericAttentionMask
|
||||
ck_tile::min(origin_end, split_end));
|
||||
}
|
||||
|
||||
template <index_t TileHeight, index_t TileWidth>
|
||||
CK_TILE_HOST_DEVICE constexpr auto GetSinkTileRangeAlongX(index_t i_y,
|
||||
number<TileHeight> height,
|
||||
number<TileWidth> width,
|
||||
index_t num_splits,
|
||||
index_t i_split) const
|
||||
{
|
||||
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
|
||||
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
|
||||
const index_t split_start = x_per_split * i_split; // 128
|
||||
const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); // 256
|
||||
const index_t sink_seq_end = sink > 0 ? ((sink + width - 1) / width) * width : 0;
|
||||
const index_t start = ck_tile::max(origin_start, split_start);
|
||||
const index_t end = ck_tile::min(origin_end, split_end);
|
||||
const bool is_first_intersecting_split =
|
||||
(split_start <= origin_start && split_end >= origin_start);
|
||||
const bool sink_in_range = (sink_seq_end <= start);
|
||||
|
||||
const index_t sink_offset =
|
||||
(is_first_intersecting_split && sink_in_range) ? sink_seq_end : 0;
|
||||
return ck_tile::make_tuple(sink_offset, start, end);
|
||||
}
|
||||
|
||||
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
||||
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
||||
@@ -368,11 +488,22 @@ struct SimplifiedGenericAttentionMask
|
||||
{
|
||||
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
|
||||
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
|
||||
|
||||
return i_x < x_start || i_x >= x_end || i_y >= y_total;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
return i_x >= x_total;
|
||||
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
|
||||
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
|
||||
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
|
||||
return false;
|
||||
else
|
||||
return i_x < x_start || i_x >= x_end || i_y >= y_total;
|
||||
}
|
||||
|
||||
// if current tile is at the edge, means need per-pixel mask check.
|
||||
// otherwise no need to check per-pixel
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
||||
@@ -406,7 +537,7 @@ struct SimplifiedGenericAttentionMask
|
||||
}
|
||||
|
||||
private:
|
||||
index_t y, x;
|
||||
index_t y, x, sink;
|
||||
index_t y_total, x_total;
|
||||
};
|
||||
|
||||
@@ -620,6 +751,7 @@ static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask<Ma
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
|
||||
index_t right_size,
|
||||
index_t sink_size,
|
||||
index_t y_total,
|
||||
index_t x_total,
|
||||
bool is_top_left = true)
|
||||
@@ -637,7 +769,21 @@ make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
|
||||
index_t x = 1 + right_size + x_tmp;
|
||||
index_t y = 1 + left_size + y_tmp;
|
||||
|
||||
return ck_tile::make_tuple(y, x, y_total, x_total);
|
||||
return ck_tile::make_tuple(y, x, sink_size, y_total, x_total);
|
||||
}
|
||||
|
||||
template <typename MaskType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_generic_attention_mask_from_lr_window(index_t left_size,
|
||||
index_t right_size,
|
||||
index_t sink_size,
|
||||
index_t y_total,
|
||||
index_t x_total,
|
||||
bool is_top_left = true)
|
||||
{
|
||||
auto r = make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, sink_size, y_total, x_total, is_top_left);
|
||||
return MaskType{r.at(number<0>{}), r.at(number<1>{}), sink_size, y_total, x_total};
|
||||
}
|
||||
|
||||
template <typename MaskType>
|
||||
@@ -649,7 +795,7 @@ make_generic_attention_mask_from_lr_window(index_t left_size,
|
||||
bool is_top_left = true)
|
||||
{
|
||||
auto r = make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, y_total, x_total, is_top_left);
|
||||
return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
|
||||
left_size, right_size, 0, y_total, x_total, is_top_left);
|
||||
return MaskType{r.at(number<0>{}), r.at(number<1>{}), 0, y_total, x_total};
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -162,6 +162,17 @@ struct StandardAttention
|
||||
{
|
||||
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
|
||||
}
|
||||
|
||||
template <typename Params>
|
||||
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
|
||||
[[maybe_unused]] uint32_t batch_idx,
|
||||
uint32_t qo_idx,
|
||||
uint32_t kv_idx,
|
||||
[[maybe_unused]] uint32_t qo_head_idx,
|
||||
[[maybe_unused]] uint32_t kv_head_idx) const
|
||||
{
|
||||
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool UseExp2 = false>
|
||||
@@ -224,6 +235,17 @@ struct LogitsSoftCap
|
||||
{
|
||||
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
|
||||
}
|
||||
|
||||
template <typename Params>
|
||||
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
|
||||
[[maybe_unused]] uint32_t batch_idx,
|
||||
uint32_t qo_idx,
|
||||
uint32_t kv_idx,
|
||||
[[maybe_unused]] uint32_t qo_head_idx,
|
||||
[[maybe_unused]] uint32_t kv_head_idx) const
|
||||
{
|
||||
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr uint32_t CUSTOM_MASK = 1U;
|
||||
@@ -297,6 +319,17 @@ struct ComposedAttention
|
||||
{
|
||||
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
|
||||
}
|
||||
|
||||
template <typename Params>
|
||||
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
|
||||
[[maybe_unused]] uint32_t batch_idx,
|
||||
uint32_t qo_idx,
|
||||
uint32_t kv_idx,
|
||||
[[maybe_unused]] uint32_t qo_head_idx,
|
||||
[[maybe_unused]] uint32_t kv_head_idx) const
|
||||
{
|
||||
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -200,7 +200,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
struct FmhaFwdMaskKargs
|
||||
{
|
||||
// ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right, sink_size;
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
@@ -356,6 +356,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
@@ -418,6 +419,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
@@ -497,6 +499,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
@@ -557,6 +560,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
@@ -1008,6 +1012,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
kargs.window_size_left,
|
||||
kargs.window_size_right,
|
||||
kargs.sink_size,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_k,
|
||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||
|
||||
@@ -58,6 +58,7 @@ struct FmhaFwdKernel
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
|
||||
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
|
||||
static constexpr bool kHasSink = FmhaPipeline::kHasSink;
|
||||
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
@@ -155,7 +156,7 @@ struct FmhaFwdKernel
|
||||
struct FmhaFwdMaskKargs
|
||||
{
|
||||
// ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right, sink_size;
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
@@ -335,6 +336,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
@@ -393,6 +395,7 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
@@ -481,6 +484,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
@@ -529,6 +533,7 @@ struct FmhaFwdKernel
|
||||
batch_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type,
|
||||
p_drop,
|
||||
s_randval,
|
||||
@@ -580,6 +585,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
@@ -628,6 +634,7 @@ struct FmhaFwdKernel
|
||||
batch_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type,
|
||||
p_drop,
|
||||
s_randval,
|
||||
@@ -673,6 +680,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
@@ -732,6 +740,7 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
@@ -817,6 +826,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
@@ -861,6 +871,7 @@ struct FmhaFwdKernel
|
||||
nhead_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type,
|
||||
min_seqlen_q,
|
||||
p_drop,
|
||||
@@ -908,6 +919,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
@@ -952,6 +964,7 @@ struct FmhaFwdKernel
|
||||
nhead_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type,
|
||||
min_seqlen_q,
|
||||
p_drop,
|
||||
@@ -1443,6 +1456,7 @@ struct FmhaFwdKernel
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
kargs.window_size_left,
|
||||
kargs.window_size_right,
|
||||
kargs.sink_size,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_k,
|
||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||
@@ -2182,6 +2196,7 @@ struct FmhaFwdKernel
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
kargs.window_size_left,
|
||||
kargs.window_size_right,
|
||||
kargs.sink_size,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_k,
|
||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||
|
||||
@@ -55,6 +55,7 @@ struct FmhaFwdPagedKVKernel
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
|
||||
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
|
||||
static constexpr bool kHasSink = FmhaPipeline::kHasSink;
|
||||
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
@@ -101,7 +102,7 @@ struct FmhaFwdPagedKVKernel
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
|
||||
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
@@ -189,7 +190,7 @@ struct FmhaFwdPagedKVKernel
|
||||
struct FmhaFwdMaskKargs
|
||||
{
|
||||
// ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right, sink_size;
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
@@ -326,6 +327,7 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
@@ -379,6 +381,7 @@ struct FmhaFwdPagedKVKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
@@ -453,6 +456,7 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
{
|
||||
return MakeKargsImpl(q_ptr,
|
||||
@@ -495,6 +499,7 @@ struct FmhaFwdPagedKVKernel
|
||||
batch_stride_o,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type);
|
||||
}
|
||||
|
||||
@@ -536,6 +541,7 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q)
|
||||
{
|
||||
@@ -590,6 +596,7 @@ struct FmhaFwdPagedKVKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
@@ -660,6 +667,7 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q)
|
||||
{
|
||||
@@ -699,6 +707,7 @@ struct FmhaFwdPagedKVKernel
|
||||
batch_stride_v,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type,
|
||||
min_seqlen_q);
|
||||
}
|
||||
@@ -1257,6 +1266,7 @@ struct FmhaFwdPagedKVKernel
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
kargs.window_size_left,
|
||||
kargs.window_size_right,
|
||||
kargs.sink_size,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_k,
|
||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||
|
||||
@@ -51,6 +51,7 @@ struct FmhaFwdSplitKVKernel
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
|
||||
static constexpr bool kHasSink = FmhaPipeline::Problem::kHasSink;
|
||||
static constexpr bool kMergeNumHeadGroupsSeqLenQ =
|
||||
FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
@@ -101,7 +102,7 @@ struct FmhaFwdSplitKVKernel
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
|
||||
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
|
||||
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
|
||||
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
@@ -198,7 +199,7 @@ struct FmhaFwdSplitKVKernel
|
||||
struct MaskKargs
|
||||
{
|
||||
// ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right, sink_size;
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
@@ -325,6 +326,7 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t split_stride_o_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
@@ -384,6 +386,7 @@ struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
@@ -451,6 +454,7 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t split_stride_o_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
@@ -508,6 +512,7 @@ struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.sink_size = sink_size;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
@@ -994,6 +999,7 @@ struct FmhaFwdSplitKVKernel
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
kargs.window_size_left,
|
||||
kargs.window_size_right,
|
||||
kargs.sink_size,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_k,
|
||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||
|
||||
@@ -57,6 +57,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
@@ -228,10 +229,22 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
if constexpr(kHasSink)
|
||||
return mask.GetSinkTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
else
|
||||
{
|
||||
auto [start, end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
return ck_tile::make_tuple(0, start, end);
|
||||
}
|
||||
}();
|
||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||
@@ -255,7 +268,6 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
return o_acc;
|
||||
}
|
||||
}
|
||||
|
||||
// k_dram_block_window
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
|
||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
|
||||
@@ -274,27 +286,36 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
return physical_seqlen_k_start_;
|
||||
}
|
||||
}();
|
||||
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
|
||||
? aligned_physical_seqlen_k_start
|
||||
: 0;
|
||||
const index_t num_total_loop =
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
|
||||
num_sink_loop;
|
||||
|
||||
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
||||
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
|
||||
k_dram_block_window_lengths, {kv_load_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
const index_t bias_n_offset = [&]() {
|
||||
if constexpr(kHasSink)
|
||||
return kv_load_start;
|
||||
else
|
||||
return logical_seqlen_k_start -
|
||||
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
|
||||
}();
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}),
|
||||
logical_seqlen_k_start - (physical_seqlen_k_start -
|
||||
aligned_physical_seqlen_k_start)}, // M/N
|
||||
{bias_origin.at(number<0>{}), bias_n_offset},
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
// v_dram_window
|
||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_block_window_lengths,
|
||||
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
||||
{0, kv_load_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||
|
||||
// prefetch K tile
|
||||
@@ -321,9 +342,16 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
}
|
||||
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
|
||||
const auto k_move_offset = [&]() {
|
||||
if constexpr(kHasSink)
|
||||
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
|
||||
else
|
||||
return kN0;
|
||||
}();
|
||||
auto physical_next_block_id_k =
|
||||
amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}));
|
||||
i_page_block_k, k_dram_block_window, {k_move_offset, 0}));
|
||||
auto physical_next_block_id_v = amd_wave_read_first_lane(
|
||||
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
|
||||
|
||||
@@ -442,7 +470,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
#endif
|
||||
}
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
move_tile_window(bias_dram_window, {0, k_move_offset});
|
||||
|
||||
{
|
||||
const auto k_origin = k_page_block_navigator.to_global_window_origin(
|
||||
@@ -474,14 +502,29 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col - kv_l2p_offset);
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask_func(row, col - kv_l2p_offset);
|
||||
});
|
||||
};
|
||||
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
apply_mask([&](auto row, auto col) {
|
||||
return mask.IsOutOfSinkBound(row, col);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_mask(
|
||||
[&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -647,7 +690,12 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
}
|
||||
// move K tile windows
|
||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
|
||||
i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k);
|
||||
physical_next_block_id_v =
|
||||
amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}));
|
||||
i_page_block_v = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v);
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
@@ -57,6 +57,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
@@ -256,11 +257,23 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
|
||||
if constexpr(kHasSink)
|
||||
return mask.GetSinkTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
else
|
||||
{
|
||||
auto [start, end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
return ck_tile::make_tuple(0, start, end);
|
||||
}
|
||||
}();
|
||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
||||
|
||||
// check early exit if no work to do
|
||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
||||
{
|
||||
const index_t logical_num_total_loop =
|
||||
@@ -304,24 +317,33 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
return physical_seqlen_k_start_;
|
||||
}
|
||||
}();
|
||||
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
|
||||
? aligned_physical_seqlen_k_start
|
||||
: 0;
|
||||
const index_t num_total_loop =
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
|
||||
num_sink_loop;
|
||||
|
||||
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
||||
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
|
||||
k_dram_block_window_lengths, {kv_load_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
const index_t bias_n_offset = [&]() {
|
||||
if constexpr(kHasSink)
|
||||
return kv_load_start;
|
||||
else
|
||||
return logical_seqlen_k_start -
|
||||
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
|
||||
}();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}),
|
||||
logical_seqlen_k_start - (physical_seqlen_k_start -
|
||||
aligned_physical_seqlen_k_start)}, // M/N
|
||||
{bias_origin.at(number<0>{}), bias_n_offset},
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_block_window_lengths,
|
||||
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
||||
{0, kv_load_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
// store Q into LDS
|
||||
@@ -369,7 +391,13 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
{
|
||||
// STAGE 1, QK gemm
|
||||
clear_tile(s_acc); // initialize C
|
||||
|
||||
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
|
||||
const auto k_move_offset = [&]() {
|
||||
if constexpr(kHasSink)
|
||||
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
|
||||
else
|
||||
return kN0;
|
||||
}();
|
||||
// load the second tile of the first iteration
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
|
||||
@@ -494,7 +522,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
#endif
|
||||
}
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
move_tile_window(bias_dram_window, {0, k_move_offset});
|
||||
|
||||
/// TODO: only check in first/last iteration without increasing code size
|
||||
if constexpr(kHasUnevenSplits)
|
||||
@@ -505,7 +533,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&,
|
||||
physical_seqlen_k_start_ = physical_seqlen_k_start,
|
||||
physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start,
|
||||
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
if constexpr(kIsPagedKV)
|
||||
@@ -530,12 +558,26 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col - kv_l2p_offset);
|
||||
});
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask_func(row, col - kv_l2p_offset);
|
||||
});
|
||||
};
|
||||
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
apply_mask(
|
||||
[&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); });
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -546,7 +588,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
{
|
||||
// move K tile windows
|
||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0});
|
||||
i_page_block_k, k_dram_block_window, {k_move_offset, 0});
|
||||
|
||||
k_dram_window = make_tile_window(
|
||||
k_dram_block_window,
|
||||
@@ -742,6 +784,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
// moving k_dram_window is an in-page-block operation, so there is
|
||||
// no need to invoke k_page_block_navigator.move_tile_window() here.
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
i_page_block_v = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v, v_dram_window, {0, k_move_offset - kN0});
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
}
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
@@ -56,6 +56,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
|
||||
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
@@ -229,9 +230,23 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
|
||||
if constexpr(kHasSink)
|
||||
return mask.GetSinkTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
else
|
||||
{
|
||||
auto [start, end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
return ck_tile::make_tuple(0, start, end);
|
||||
}
|
||||
}();
|
||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
||||
|
||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
||||
@@ -274,24 +289,35 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
return physical_seqlen_k_start_;
|
||||
}
|
||||
}();
|
||||
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
|
||||
? aligned_physical_seqlen_k_start
|
||||
: 0;
|
||||
const index_t num_total_loop =
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
|
||||
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
|
||||
num_sink_loop;
|
||||
|
||||
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
|
||||
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
|
||||
k_dram_block_window_lengths, {kv_load_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
const index_t bias_n_offset = [&]() {
|
||||
if constexpr(kHasSink)
|
||||
return kv_load_start;
|
||||
else
|
||||
return logical_seqlen_k_start -
|
||||
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
|
||||
}();
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}),
|
||||
logical_seqlen_k_start - (physical_seqlen_k_start -
|
||||
aligned_physical_seqlen_k_start)}, // M/N
|
||||
{bias_origin.at(number<0>{}), bias_n_offset},
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_block_window_lengths,
|
||||
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
|
||||
{0, kv_load_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||
@@ -320,9 +346,18 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
}
|
||||
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
|
||||
|
||||
const auto k_move_offset = [&]() {
|
||||
if constexpr(kHasSink)
|
||||
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
|
||||
else
|
||||
return kN0;
|
||||
}();
|
||||
|
||||
auto physical_next_block_id_k =
|
||||
amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}));
|
||||
i_page_block_k, k_dram_block_window, {k_move_offset, 0}));
|
||||
auto physical_next_block_id_v = amd_wave_read_first_lane(
|
||||
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
|
||||
|
||||
@@ -441,7 +476,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
#endif
|
||||
}
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
move_tile_window(bias_dram_window, {0, k_move_offset});
|
||||
|
||||
/// TODO: only check in first/last iteration without increasing code size
|
||||
if constexpr(kHasUnevenSplits)
|
||||
@@ -452,7 +487,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&,
|
||||
physical_seqlen_k_start_ = physical_seqlen_k_start,
|
||||
physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start,
|
||||
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
if constexpr(kIsPagedKV)
|
||||
@@ -477,12 +512,26 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col - kv_l2p_offset);
|
||||
});
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask_func(row, col - kv_l2p_offset);
|
||||
});
|
||||
};
|
||||
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
apply_mask(
|
||||
[&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); });
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -647,7 +696,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
}
|
||||
// move K tile windows
|
||||
i_page_block_k = k_page_block_navigator.move_tile_window(
|
||||
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
|
||||
i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k);
|
||||
physical_next_block_id_v =
|
||||
amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id(
|
||||
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}));
|
||||
i_page_block_v = v_page_block_navigator.move_tile_window(
|
||||
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v);
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
@@ -62,6 +62,7 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr auto QScaleEnum = Traits::QScaleEnum;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
static constexpr bool kHasSink = Traits::kHasSink;
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
@@ -114,6 +115,7 @@ struct BlockFmhaFwdPagedKVPipelineProblem
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
static constexpr bool kHasSink = Traits::kHasSink;
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
@@ -167,6 +169,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem
|
||||
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
|
||||
static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
static constexpr bool kHasSink = Traits::kHasSink;
|
||||
};
|
||||
|
||||
// extract tile size attributes to remove dependency on traits
|
||||
|
||||
@@ -57,6 +57,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
|
||||
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
|
||||
@@ -233,10 +234,26 @@ struct BlockFmhaPipelineQRKSVS
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
if constexpr(kHasSink)
|
||||
return mask.GetSinkTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
else
|
||||
{
|
||||
auto [start, end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
return ck_tile::make_tuple(0, start, end);
|
||||
}
|
||||
}();
|
||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
||||
|
||||
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
|
||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
||||
const auto num_total_loop =
|
||||
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||
@@ -262,22 +279,22 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0});
|
||||
{kv_load_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
|
||||
randval_dram_block_window_tmp, seqlen_k_start);
|
||||
randval_dram_block_window_tmp, kv_load_start);
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
{0, kv_load_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||
@@ -450,6 +467,11 @@ struct BlockFmhaPipelineQRKSVS
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
if(i_total_loops == 0)
|
||||
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
@@ -460,17 +482,34 @@ struct BlockFmhaPipelineQRKSVS
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !variant.LogitsMask(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !mask_func(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
};
|
||||
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsSinkMask(std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -580,11 +619,23 @@ struct BlockFmhaPipelineQRKSVS
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
// K and dropout use the same address in LDS, finish loading from k_lds_window by
|
||||
// gemm_0 to reuse LDS.
|
||||
block_sync_lds();
|
||||
auto randval_ptr = reinterpret_cast<char*>(smem_ptr);
|
||||
|
||||
index_t seq_offset = [&]() {
|
||||
if constexpr(!kHasSink)
|
||||
return seqlen_k_start + i_total_loops * kN0;
|
||||
|
||||
const bool in_sink_phase = (num_sink_loop > i_total_loops);
|
||||
if(i_total_loops == num_sink_loop)
|
||||
move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end});
|
||||
|
||||
return in_sink_phase ? (kv_load_start + i_total_loops * kN0)
|
||||
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
|
||||
}();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
|
||||
randval_ptr, seq_offset, p_compute, randval_dram_window);
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
@@ -636,6 +687,14 @@ struct BlockFmhaPipelineQRKSVS
|
||||
});
|
||||
}
|
||||
// move K tile windows
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
if(i_total_loops == 0)
|
||||
{
|
||||
move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0});
|
||||
move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end});
|
||||
}
|
||||
}
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
// tail
|
||||
{
|
||||
|
||||
@@ -62,6 +62,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
@@ -277,11 +278,26 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
clear_tile(l);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
if constexpr(kHasSink)
|
||||
return mask.GetSinkTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
else
|
||||
{
|
||||
auto [start, end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
return ck_tile::make_tuple(0, start, end);
|
||||
}
|
||||
}();
|
||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
|
||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
||||
const auto num_total_loop =
|
||||
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||
@@ -309,7 +325,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0});
|
||||
{kv_load_start, 0});
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window.get_bottom_tensor_view(),
|
||||
@@ -332,16 +348,16 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
|
||||
randval_dram_block_window_tmp, seqlen_k_start);
|
||||
randval_dram_block_window_tmp, kv_load_start);
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
{0, kv_load_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
// prefetch K tile
|
||||
@@ -478,6 +494,11 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
if(i_total_loops == 0)
|
||||
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
@@ -489,17 +510,34 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !variant.LogitsMask(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !mask_func(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
};
|
||||
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsSinkMask(std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -647,11 +685,21 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
auto randval_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
index_t seq_offset = [&]() {
|
||||
if constexpr(!kHasSink)
|
||||
return seqlen_k_start + i_total_loops * kN0;
|
||||
|
||||
const bool in_sink_phase = (num_sink_loop > i_total_loops);
|
||||
if(i_total_loops == num_sink_loop)
|
||||
move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end});
|
||||
|
||||
return in_sink_phase ? (kv_load_start + i_total_loops * kN0)
|
||||
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
|
||||
}();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||
randval_ptr,
|
||||
seqlen_k_start + i_total_loops * kN0,
|
||||
p_compute,
|
||||
randval_dram_window);
|
||||
randval_ptr, seq_offset, p_compute, randval_dram_window);
|
||||
}
|
||||
|
||||
const auto p = [&]() {
|
||||
@@ -717,8 +765,16 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
i_total_loops++;
|
||||
if(i_total_loops < num_total_loop)
|
||||
{
|
||||
// move K tile windows
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
if(i_total_loops == 0)
|
||||
{
|
||||
move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0});
|
||||
move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end});
|
||||
}
|
||||
}
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
|
||||
if constexpr(k1_loops >= 2 &&
|
||||
|
||||
@@ -69,6 +69,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasUnevenSplits = true;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
|
||||
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
|
||||
|
||||
@@ -20,8 +20,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kStoreLSE_,
|
||||
bool kHasDropout_,
|
||||
BlockAttentionQuantScaleEnum QScaleEnum_,
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
|
||||
bool kHasSink_ = false>
|
||||
struct TileFmhaTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
@@ -36,6 +37,7 @@ struct TileFmhaTraits
|
||||
static constexpr auto QScaleEnum = QScaleEnum_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <index_t kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
@@ -65,8 +67,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
|
||||
bool kIsPagedKV_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
|
||||
bool kHasSink_ = false>
|
||||
struct TileFmhaFwdPagedKVTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
@@ -81,6 +84,7 @@ struct TileFmhaFwdPagedKVTraits
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
@@ -95,7 +99,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kIsPagedKV_,
|
||||
bool kHasUnevenSplits_,
|
||||
bool kMergeNumHeadGroupsSeqLenQ_ = false,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kHasSink_ = false>
|
||||
struct TileFmhaFwdSplitKVTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
@@ -112,6 +117,7 @@ struct TileFmhaFwdSplitKVTraits
|
||||
static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
|
||||
static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
|
||||
@@ -986,6 +986,8 @@ struct MoeSortingKernel
|
||||
p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
smem_cumdup(num_experts) = smem_cumsum(num_experts);
|
||||
|
||||
// fill the p_sorted_token_ids/p_sorted_weights
|
||||
|
||||
@@ -33,9 +33,10 @@
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
||||
|
||||
@@ -232,7 +232,7 @@ struct BatchedGemmKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr1[GetSmemSize()];
|
||||
__shared__ char smem_ptr1[GemmPipeline::GetSmemSize()];
|
||||
UniversalGemmKernel::RunGemm2LDS({a_ptr},
|
||||
{b_ptr},
|
||||
{/*ds_ptr*/},
|
||||
|
||||
@@ -310,7 +310,7 @@ struct GroupedGemmKernel
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
RunGemmWithPipelineSelection2LDS(a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
@@ -561,6 +561,7 @@ struct GroupedGemmKernel
|
||||
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
|
||||
0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
|
||||
Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
|
||||
block_sync_lds();
|
||||
block_id = block_id + grid_size; // advance to next block
|
||||
// NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
|
||||
if(block_id >= cum_grid_size)
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename CompilerTarget, typename Enabler = void>
|
||||
struct StreamKCoherency
|
||||
{
|
||||
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
|
||||
amd_buffer_coherence_enum::coherence_default;
|
||||
};
|
||||
|
||||
template <typename CompilerTarget>
|
||||
struct StreamKCoherency<CompilerTarget,
|
||||
core::arch::enable_if_target_id_t<CompilerTarget,
|
||||
core::arch::amdgcn_target_id::GFX942,
|
||||
core::arch::amdgcn_target_id::GFX950>>
|
||||
{
|
||||
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
|
||||
amd_buffer_coherence_enum::SYSTEM_NT0;
|
||||
};
|
||||
|
||||
template <typename CompilerTarget>
|
||||
struct StreamKCoherency<CompilerTarget,
|
||||
core::arch::enable_if_target_id_t<CompilerTarget,
|
||||
core::arch::amdgcn_target_id::GFX908,
|
||||
core::arch::amdgcn_target_id::GFX90A>>
|
||||
{
|
||||
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
|
||||
amd_buffer_coherence_enum::glc_slc;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "streamk_gemm_coherency.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -318,37 +319,58 @@ struct StreamKKernel
|
||||
* results.
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the current thread block (CTA).
|
||||
* @note This function utilizes a workgroup barrier to set a synchronization flag for the given
|
||||
* CTA index.
|
||||
* @note This function utilizes a scalar store to write to the flags buffer.
|
||||
*/
|
||||
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
|
||||
index_t cta_idx) const
|
||||
{
|
||||
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
|
||||
workgroup_barrier sk_flags(sk_flags_ptr);
|
||||
sk_flags.wait_set(0, 1, cta_idx);
|
||||
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
|
||||
index_t offset = cta_idx * sizeof(index_t);
|
||||
|
||||
asm volatile("s_mov_b32 m0, %2\n\t"
|
||||
// Depending on the architecture, the GLC flag will bypass the approproriate
|
||||
// cache level(s) to ensure the write is visible to other workgroups. See the
|
||||
// appropriate ISA for details about the GLC modifier.
|
||||
"s_store_dword %0, %1, %2 glc\n\t"
|
||||
"s_waitcnt lgkmcnt(0)" // Wait for the store to complete
|
||||
:
|
||||
: "s"(1), "s"(sk_flags_ptr), "s"(offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Waits for the thread block (cta_idx) to complete storing its partial results.
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @note This function utilizes a workgroup barrier to wait for the synchronization flag to be
|
||||
* set by the given CTA index.
|
||||
* @note This function utilizes a scalar load to read from the flags
|
||||
* buffer.
|
||||
*/
|
||||
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
|
||||
{
|
||||
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
|
||||
workgroup_barrier sk_flags(sk_flags_ptr);
|
||||
sk_flags.wait_eq(1, cta_idx);
|
||||
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
|
||||
index_t result;
|
||||
index_t offset = cta_idx * sizeof(index_t);
|
||||
|
||||
do
|
||||
{
|
||||
asm volatile("s_mov_b32 m0, %2\n\t"
|
||||
// Depending on the architecture, the GLC flag will bypass the
|
||||
// approproriate cache level(s) to avoid reading stale flags. See the
|
||||
// appropriate ISA for details about the GLC modifier.
|
||||
"s_load_dword %0, %1, %2 glc\n\t"
|
||||
"s_waitcnt lgkmcnt(0)" // Wait for the load to complete
|
||||
: "=s"(result)
|
||||
: "s"(sk_flags_ptr), "s"(offset)
|
||||
: "memory");
|
||||
} while(result != 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Adds the values of a block tile to an output block tile.
|
||||
* @param in_out_block_tile The output block tile to which values are added.
|
||||
* @param in_block_tile The input block tile whose values are added.
|
||||
* @note This function iterates over the distributed spans of the block tiles and updates the
|
||||
* output block tile with accumulated values.
|
||||
* @note This function iterates over the distributed spans of the block tiles and updates
|
||||
* the output block tile with accumulated values.
|
||||
*/
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
|
||||
@@ -370,8 +392,8 @@ struct StreamKKernel
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @param c_block_tile_dist The tile distribution for the block.
|
||||
* @return The loaded partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile distribution for loading
|
||||
* the partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile distribution for
|
||||
* loading the partial block tile.
|
||||
*/
|
||||
template <typename DataType, typename OAccTileDist>
|
||||
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
|
||||
@@ -405,8 +427,8 @@ struct StreamKKernel
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @param c_block_tile The block tile to be stored.
|
||||
* @note This function calculates the buffer pointer and uses the tile window for storing the
|
||||
* partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile window for storing
|
||||
* the partial block tile.
|
||||
*/
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
|
||||
@@ -420,7 +442,10 @@ struct StreamKKernel
|
||||
kargs.tile_partitioner.get_flags_buffer_size() +
|
||||
cta_idx * c_block_tile_buffer_size;
|
||||
|
||||
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
const auto& partial_tensor_view = make_naive_tensor_view<
|
||||
address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE>(
|
||||
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
make_tuple(TilePartitioner::NPerBlock, 1),
|
||||
@@ -431,8 +456,11 @@ struct StreamKKernel
|
||||
partial_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{0, 0});
|
||||
|
||||
store_tile(partial_tile_window, c_block_tile);
|
||||
// Wait for all vector stores for this wavefront to complete
|
||||
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
|
||||
// Wait for all wavefronts in this workgroup to arrive here before continuing
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -483,7 +511,8 @@ struct StreamKKernel
|
||||
{
|
||||
BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
else if(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Reduction ||
|
||||
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::TreeReduction)
|
||||
{
|
||||
const auto c_macro_tile_idx =
|
||||
kargs.tile_partitioner.get_output_tile_index(tile_idx);
|
||||
@@ -528,46 +557,107 @@ struct StreamKKernel
|
||||
|
||||
auto tile_started = iter_start == tile_iter_start;
|
||||
auto tile_ended = iter_end >= tile_iter_end;
|
||||
if(!tile_started)
|
||||
|
||||
if constexpr(TilePartitioner::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
StorePartial(kargs, cta_idx, c_block_tile);
|
||||
// Ensure device-wide visibility of partial results stored in global memory
|
||||
// before signaling completion. __threadfence() guarantees that all global
|
||||
// memory writes by this thread are visible to other threads on the device.
|
||||
__threadfence(); // send signal when the store is done
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
if(!tile_started)
|
||||
{
|
||||
StorePartial(kargs, cta_idx, c_block_tile);
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto accum_block_tile = c_block_tile;
|
||||
if(!tile_ended)
|
||||
{
|
||||
const index_t iter_per_tile =
|
||||
kargs.tile_partitioner.get_iters_per_tile();
|
||||
const index_t iter_per_cta =
|
||||
kargs.tile_partitioner.get_iters_per_sk_cta();
|
||||
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
|
||||
int accum_iters = local_iter_end - local_iter_start;
|
||||
int next_cta = cta_idx + 1;
|
||||
|
||||
while(accum_iters < iter_per_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, next_cta);
|
||||
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(
|
||||
accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs, next_cta, c_block_tile.get_tile_distribution()));
|
||||
|
||||
accum_iters += iter_per_cta + (next_cta < extra_iters);
|
||||
++next_cta;
|
||||
}
|
||||
}
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
else
|
||||
else // Tree Reduction
|
||||
{
|
||||
auto accum_block_tile = c_block_tile;
|
||||
if(!tile_ended)
|
||||
index_t tile_local_cta_idx =
|
||||
kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx);
|
||||
|
||||
for(index_t stride = 1;; stride <<= 1)
|
||||
{
|
||||
const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
|
||||
const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
|
||||
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
|
||||
int accum_iters = local_iter_end - local_iter_start;
|
||||
int next_cta = cta_idx + 1;
|
||||
const index_t partner_cta_idx = cta_idx + stride;
|
||||
const index_t partner_start_iter =
|
||||
kargs.tile_partitioner.get_start_iter(partner_cta_idx);
|
||||
bool partner_in_tile = partner_start_iter < tile_iter_end;
|
||||
|
||||
while(accum_iters < iter_per_tile)
|
||||
// If the partner of the workgroup who started the tile is not in this tile,
|
||||
// then the work for this tile is done and results can be stored in the C
|
||||
// tensor.
|
||||
if(tile_started && !partner_in_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, next_cta);
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
break;
|
||||
}
|
||||
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(
|
||||
accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs, next_cta, c_block_tile.get_tile_distribution()));
|
||||
|
||||
accum_iters += iter_per_cta + (next_cta < extra_iters);
|
||||
++next_cta;
|
||||
// It's this workgroup's turn to read from partials.
|
||||
if(tile_local_cta_idx % (stride << 1) == 0)
|
||||
{
|
||||
// If this workgroup's partner is in the tile then it can read from
|
||||
// partials and accumulate results.
|
||||
if(partner_in_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, partner_cta_idx);
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs,
|
||||
partner_cta_idx,
|
||||
c_block_tile.get_tile_distribution()));
|
||||
}
|
||||
}
|
||||
// Otherwise, it's this workgroup's turn to write to partials. All
|
||||
// workgroups, except the workgroup who starts the tile, will write to
|
||||
// partials.
|
||||
else
|
||||
{
|
||||
StorePartial(kargs, cta_idx, accum_block_tile);
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
// Once the workgroup writes to partials, it has no more work to do for
|
||||
// this tile.
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(
|
||||
"An implementation does not exist for the chosen reduction strategy.");
|
||||
}
|
||||
|
||||
// Prepare for next Stream-K loop iteration.
|
||||
iter_start = tile_iter_end;
|
||||
@@ -631,6 +721,7 @@ struct StreamKKernel
|
||||
tile_idx += kargs.tile_partitioner.get_grid())
|
||||
{
|
||||
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// Stream-K section
|
||||
@@ -639,10 +730,10 @@ struct StreamKKernel
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
|
||||
* the starting macro tile index in the K dimension for the workgroup.
|
||||
* @return A tuple containing the offsets into the A and B tensors accounting for the layouts
|
||||
* of A and B.
|
||||
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset
|
||||
* is the starting macro tile index in the K dimension for the workgroup.
|
||||
* @return A tuple containing the offsets into the A and B tensors accounting for the
|
||||
* layouts of A and B.
|
||||
* @note The default case is that A is assumed to be row major and B is assumed to be column
|
||||
* major.
|
||||
*/
|
||||
@@ -679,15 +770,16 @@ struct StreamKKernel
|
||||
{
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
hip_check_error(hipGetDevice(&dev));
|
||||
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
ck_tile::hip_check_error(hipGetDevice(&dev));
|
||||
ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
int num_cu = dev_prop.multiProcessorCount;
|
||||
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
|
||||
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the
|
||||
* kernel
|
||||
* @return The occupancy
|
||||
* @note This function queries the maximum occupancy of the kernel using
|
||||
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
@@ -700,7 +792,7 @@ struct StreamKKernel
|
||||
constexpr int min_block_per_cu = 1;
|
||||
const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
|
||||
|
||||
hip_check_error(
|
||||
ck_tile::hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
|
||||
|
||||
return max(occupancy, 1);
|
||||
@@ -46,6 +46,16 @@ struct StreamKTilePartitionerBase
|
||||
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Calculates the start iteration for the given the cta_idx.
|
||||
* @param cta_idx The current Stream-K workgroup's index.
|
||||
* @return index_t The start iteration.
|
||||
* @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
|
||||
* non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something
|
||||
* like `blockIdx.x` minus number of DP workgroups.
|
||||
*/
|
||||
CK_TILE_DEVICE index_t get_start_iter(index_t cta_idx) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the start and end iteration given the cta_idx.
|
||||
*
|
||||
@@ -107,7 +117,17 @@ struct StreamKTilePartitionerBase
|
||||
get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroups 2D tile index in the C tensor given the 1D tile index.
|
||||
* @brief Calculates the workgroup's local CTA idx within the given tile.
|
||||
*
|
||||
* @param tile_iter_start The starting tile iteration.
|
||||
* @param cta_idx The Stream-K workgroup index.
|
||||
* @return index_t The tile local workgroup index in the tile.
|
||||
*/
|
||||
CK_TILE_DEVICE index_t get_tile_local_cta_index(index_t tile_iter_start,
|
||||
index_t cta_idx) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index.
|
||||
*
|
||||
* @param tile_idx The 1D tile index in the C tensor for the workgroup.
|
||||
* @return index_t The corresponding 2D tile index in the C tensor for the workgroup.
|
||||
@@ -61,13 +61,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags
|
||||
return sizeof(index_t) * sk_ctas_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_start_iter(
|
||||
index_t cta_idx) const noexcept
|
||||
{
|
||||
// Compute the number of extra iterations done before this CTA. If the cta_idx is less than
|
||||
// extra_iters, the number of extra iterations before the CTA is exactly the cta_idx. Otherwise,
|
||||
// it is extra_iters.
|
||||
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
|
||||
return total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE void
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iter_boundaries(
|
||||
index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
|
||||
{
|
||||
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
|
||||
iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
|
||||
iter = get_start_iter(cta_idx);
|
||||
iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
|
||||
}
|
||||
|
||||
@@ -104,6 +115,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_local
|
||||
return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_tile_local_cta_index(
|
||||
index_t tile_iter_start, index_t cta_idx) const noexcept
|
||||
{
|
||||
tile_iter_start = tile_iter_start - (dp_tiles_ * iters_per_tile_);
|
||||
|
||||
// Compute how many WGs fit before this tile starts assuming each WG does an
|
||||
// extra_iter
|
||||
const index_t num_extra_iter_ctas = tile_iter_start / (iters_per_sk_cta_ + 1);
|
||||
// Compute how many WGs fit before this tile starts excluding extra iters
|
||||
const index_t num_non_extra_iter_ctas = (tile_iter_start - extra_iters_) / iters_per_sk_cta_;
|
||||
// Compute the CTA idx for the CTA that starts this tile
|
||||
const index_t coop_group_start =
|
||||
num_extra_iter_ctas < extra_iters_ ? num_extra_iter_ctas : num_non_extra_iter_ctas;
|
||||
return cta_idx - coop_group_start;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE auto
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_output_tile_index(
|
||||
@@ -121,7 +150,8 @@ CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_workspace_size(
|
||||
index_t acc_element_bytes) const noexcept
|
||||
{
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction ||
|
||||
ReductionStrategy == StreamKReductionStrategy::TreeReduction)
|
||||
{
|
||||
|
||||
return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();
|
||||
@@ -280,7 +280,7 @@ struct UniversalGemmKernel
|
||||
using Kernel = UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
const auto kernel = kentry<1, Kernel, KernelArgs>;
|
||||
int occupancy;
|
||||
hip_check_error(
|
||||
ck_tile::hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));
|
||||
|
||||
const int grid_size = get_available_compute_units(s) * occupancy;
|
||||
@@ -1084,7 +1084,7 @@ struct UniversalGemmKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
@@ -1169,7 +1169,7 @@ struct UniversalGemmKernel
|
||||
// Run the GEMM
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
|
||||
@@ -9,11 +9,35 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr bool UsePersistentKernel = false;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
|
||||
{
|
||||
return TailNumber::Empty;
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV1
|
||||
struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
|
||||
{
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
@@ -48,14 +72,14 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Problem::VectorSizeA;
|
||||
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Problem::VectorSizeB;
|
||||
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
@@ -9,11 +9,34 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAGmemBGmemCRegV2
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
|
||||
{
|
||||
return TailNumber::Empty;
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV2
|
||||
struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2<Problem>
|
||||
{
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
|
||||
@@ -43,4 +43,26 @@ struct TileGemmShape
|
||||
}
|
||||
};
|
||||
|
||||
template <typename PrecType, index_t M_Warp_Tile, bool IsFlatMM = false>
|
||||
constexpr index_t get_k_warp_tile()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 16 : 32;
|
||||
else
|
||||
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 32 : 64;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -43,13 +43,14 @@ template <bool kPadM_,
|
||||
bool UseStructuredSparsity_ = false,
|
||||
bool UsePersistentKernel_ = false,
|
||||
index_t NumWaveGroups_ = 1,
|
||||
bool Preshuffle_ = false>
|
||||
bool Preshuffle_ = false,
|
||||
int VectorSize_ = 16>
|
||||
struct TileGemmUniversalTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
static constexpr int _VectorSize = 16;
|
||||
static constexpr int _VectorSize = VectorSize_;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
|
||||
using AsLayout = AsLayout_;
|
||||
|
||||
@@ -61,6 +61,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
@@ -156,9 +157,11 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
// BDataType gets converted from PkInt4 during loading
|
||||
using OverrideBDataType =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
|
||||
using OverrideBDataType = std::conditional_t<
|
||||
std::is_same_v<BDataType, pk_int4_t> &&
|
||||
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
|
||||
ADataType,
|
||||
BDataType>;
|
||||
using Base = BlockGemmBQuantBase<Problem_>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
@@ -1404,7 +1404,7 @@ struct QuantGemmKernel
|
||||
assert(kargs.k_batch == 1);
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
|
||||
@@ -325,7 +325,7 @@ struct QuantGroupedGemmKernel
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
RunGemmWithPipelineSelection2LDS(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
|
||||
@@ -33,9 +33,17 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
// BDataType gets converted from PkInt4 during loading
|
||||
using OverrideBDataType =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t> &&
|
||||
std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>,
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
|
||||
using I0 = number<0>;
|
||||
@@ -50,11 +58,6 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr index_t BQPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BQDataType>>::PackedSize;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
@@ -184,6 +187,23 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
|
||||
}
|
||||
|
||||
template <typename BBlockTile_, typename BDramWindow, typename BDramTileWindowStep>
|
||||
CK_TILE_DEVICE void
|
||||
BGlobalPrefetch(BBlockTile_& b_block_tile,
|
||||
BDramWindow& b_copy_dram_window,
|
||||
const BDramTileWindowStep& b_dram_tile_window_step) const
|
||||
{
|
||||
if constexpr(!std::is_same_v<BDataType, OverrideBDataType>)
|
||||
{
|
||||
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
@@ -262,7 +282,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(BBlockTileDistr{}));
|
||||
decltype(make_static_distributed_tensor<OverrideBDataType>(BBlockTileDistr{}));
|
||||
using BQBlockTile =
|
||||
decltype(make_static_distributed_tensor<BQDataType>(BQBlockTileDistr{}));
|
||||
|
||||
@@ -292,8 +312,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
// DRAM prefetch (global read 0)
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
// B tile gets converted to A datatype during loading
|
||||
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step);
|
||||
|
||||
@@ -314,7 +333,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
// B datatype is converted to A datatype during loading
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
@@ -325,8 +344,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
// B tile gets converted to A datatype during loading
|
||||
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -369,8 +388,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
// B tile gets converted to A datatype during loading
|
||||
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2],
|
||||
bq_copy_dram_window,
|
||||
bq_dram_tile_window_step);
|
||||
|
||||
@@ -1048,7 +1048,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
|
||||
@@ -1005,7 +1005,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
|
||||
@@ -1184,7 +1184,7 @@ struct GroupedConvolutionForwardKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
|
||||
@@ -425,7 +425,7 @@ struct BlockReduce2dCrossWarpSync
|
||||
|
||||
if constexpr(num_reduce_warps == 1)
|
||||
return;
|
||||
|
||||
block_sync_lds();
|
||||
// Each warp's lane 0 writes its partial results to shared memory
|
||||
const index_t smem_offset = warp_id;
|
||||
if(lane_id == 0)
|
||||
|
||||
@@ -160,23 +160,23 @@ void dump_gemm_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_batched_gemm_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
int batch_stride_A,
|
||||
int batch_stride_B,
|
||||
int batch_stride_C,
|
||||
int batch_count,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "batched_gemm_basic")
|
||||
inline void dump_batched_gemm_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
int batch_stride_A,
|
||||
int batch_stride_B,
|
||||
int batch_stride_C,
|
||||
int batch_count,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "batched_gemm_basic")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -218,20 +218,20 @@ void dump_grouped_gemm_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_flatmm_json_results(const std::string& json_filename,
|
||||
const std::string& datatype,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
int kbatch,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "flatmm_basic")
|
||||
inline void dump_flatmm_json_results(const std::string& json_filename,
|
||||
const std::string& datatype,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
int kbatch,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "flatmm_basic")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -248,21 +248,22 @@ void dump_flatmm_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideD0,
|
||||
int StrideD1,
|
||||
int StrideE,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "gemm_multi_d_fp16")
|
||||
inline void
|
||||
dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideD0,
|
||||
int StrideD1,
|
||||
int StrideE,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "gemm_multi_d_fp16")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -280,14 +281,14 @@ void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_elementwise_json_results(const std::string& json_filename,
|
||||
const std::string& prec,
|
||||
int grid_size,
|
||||
int block_size,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "elementwise")
|
||||
inline void dump_elementwise_json_results(const std::string& json_filename,
|
||||
const std::string& prec,
|
||||
int grid_size,
|
||||
int block_size,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "elementwise")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -298,22 +299,22 @@ void dump_elementwise_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_layernorm2d_fwd_json_results(const std::string& json_filename,
|
||||
const std::string& prec_i,
|
||||
const std::string& prec_o,
|
||||
const std::string& prec_sm,
|
||||
const std::string& prec_sy,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int xr_stride,
|
||||
int y_stride,
|
||||
int yr_stride,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "layernorm2d_fwd")
|
||||
inline void dump_layernorm2d_fwd_json_results(const std::string& json_filename,
|
||||
const std::string& prec_i,
|
||||
const std::string& prec_o,
|
||||
const std::string& prec_sm,
|
||||
const std::string& prec_sy,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int xr_stride,
|
||||
int y_stride,
|
||||
int yr_stride,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "layernorm2d_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -357,13 +358,13 @@ void dump_reduce_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_permute_json_results(const std::string& json_filename,
|
||||
const std::string& data_type,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflop,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "permute")
|
||||
inline void dump_permute_json_results(const std::string& json_filename,
|
||||
const std::string& data_type,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflop,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "permute")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -373,19 +374,19 @@ void dump_permute_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_topk_softmax_json(const std::string& json_filename,
|
||||
const std::string& input_prec,
|
||||
const std::string& weight_prec,
|
||||
int tokens,
|
||||
int experts,
|
||||
int topk,
|
||||
int stride_input,
|
||||
int stride_output,
|
||||
float ave_time,
|
||||
float tflop,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "topk_softmax")
|
||||
inline void dump_topk_softmax_json(const std::string& json_filename,
|
||||
const std::string& input_prec,
|
||||
const std::string& weight_prec,
|
||||
int tokens,
|
||||
int experts,
|
||||
int topk,
|
||||
int stride_input,
|
||||
int stride_output,
|
||||
float ave_time,
|
||||
float tflop,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "topk_softmax")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -401,20 +402,20 @@ void dump_topk_softmax_json(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
|
||||
const std::string& prec_str,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int xr_stride,
|
||||
int y_stride,
|
||||
int yr_stride,
|
||||
int use_model_sensitive_rmsnorm,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "rmsnorm2d_fwd")
|
||||
inline void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
|
||||
const std::string& prec_str,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int xr_stride,
|
||||
int y_stride,
|
||||
int yr_stride,
|
||||
int use_model_sensitive_rmsnorm,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "rmsnorm2d_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -431,19 +432,19 @@ void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_add_rmsnorm2d_rdquant_fwd_json(
|
||||
const std::string& json_filename,
|
||||
const std::string& input_data_type,
|
||||
const std::string& quantized_data_type,
|
||||
int m,
|
||||
int n,
|
||||
int stride,
|
||||
float epsilon,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd")
|
||||
inline void
|
||||
dump_add_rmsnorm2d_rdquant_fwd_json(const std::string& json_filename,
|
||||
const std::string& input_data_type,
|
||||
const std::string& quantized_data_type,
|
||||
int m,
|
||||
int n,
|
||||
int stride,
|
||||
float epsilon,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -458,17 +459,17 @@ void dump_add_rmsnorm2d_rdquant_fwd_json(
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_smoothquant_json(const std::string& json_filename,
|
||||
const std::string& prec_str,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int y_stride,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "smoothquant")
|
||||
inline void dump_smoothquant_json(const std::string& json_filename,
|
||||
const std::string& prec_str,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int y_stride,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "smoothquant")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -482,19 +483,19 @@ void dump_smoothquant_json(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_moe_sorting_json(const std::string& json_filename,
|
||||
const std::string& index_prec,
|
||||
const std::string& weight_prec,
|
||||
const std::string& workspace_size,
|
||||
int dispatch_policy,
|
||||
int tokens,
|
||||
int num_experts,
|
||||
int topk,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "moe_sorting")
|
||||
inline void dump_moe_sorting_json(const std::string& json_filename,
|
||||
const std::string& index_prec,
|
||||
const std::string& weight_prec,
|
||||
const std::string& workspace_size,
|
||||
int dispatch_policy,
|
||||
int tokens,
|
||||
int num_experts,
|
||||
int topk,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "moe_sorting")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -510,19 +511,19 @@ void dump_moe_sorting_json(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_batched_transpose_json(const std::string& json_filename,
|
||||
int N,
|
||||
int C,
|
||||
int H,
|
||||
int W,
|
||||
const std::string& layout_in,
|
||||
const std::string& layout_out,
|
||||
const std::string& prec,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "batched_transpose")
|
||||
inline void dump_batched_transpose_json(const std::string& json_filename,
|
||||
int N,
|
||||
int C,
|
||||
int H,
|
||||
int W,
|
||||
const std::string& layout_in,
|
||||
const std::string& layout_out,
|
||||
const std::string& prec,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "batched_transpose")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -538,19 +539,19 @@ void dump_batched_transpose_json(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_moe_smoothquant_json(const std::string& json_filename,
|
||||
const std::string& prec_i,
|
||||
const std::string& prec_o,
|
||||
int tokens,
|
||||
int hidden_size,
|
||||
int stride,
|
||||
int experts,
|
||||
int topk,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "moe_smoothquant")
|
||||
inline void dump_moe_smoothquant_json(const std::string& json_filename,
|
||||
const std::string& prec_i,
|
||||
const std::string& prec_o,
|
||||
int tokens,
|
||||
int hidden_size,
|
||||
int stride,
|
||||
int experts,
|
||||
int topk,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "moe_smoothquant")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -566,26 +567,26 @@ void dump_moe_smoothquant_json(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_fused_moe_json(const std::string& json_filename,
|
||||
const std::string& api_str,
|
||||
const std::string& prec_str,
|
||||
int tokens,
|
||||
bool is_local_token,
|
||||
int local_tokens,
|
||||
int experts,
|
||||
int topk,
|
||||
int hidden_size,
|
||||
int intermediate_size,
|
||||
int stride,
|
||||
int block_m,
|
||||
int activation,
|
||||
bool gate_only,
|
||||
bool fused_quant,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float tb_per_sec,
|
||||
const std::string& kernel_name = "fused_moe")
|
||||
inline void dump_fused_moe_json(const std::string& json_filename,
|
||||
const std::string& api_str,
|
||||
const std::string& prec_str,
|
||||
int tokens,
|
||||
bool is_local_token,
|
||||
int local_tokens,
|
||||
int experts,
|
||||
int topk,
|
||||
int hidden_size,
|
||||
int intermediate_size,
|
||||
int stride,
|
||||
int block_m,
|
||||
int activation,
|
||||
bool gate_only,
|
||||
bool fused_quant,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float tb_per_sec,
|
||||
const std::string& kernel_name = "fused_moe")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -610,29 +611,29 @@ void dump_fused_moe_json(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_fmha_fwd_json_results(const std::string& json_filename,
|
||||
const std::string& prec,
|
||||
const std::string& mode,
|
||||
const std::string& io_layout,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_qs,
|
||||
int seqlen_ks,
|
||||
int seqlen_kpads,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
float scale_s,
|
||||
float p_drop,
|
||||
bool lse,
|
||||
const std::string& qscale,
|
||||
const std::string& bias,
|
||||
const std::string& vlayout,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "fmha_fwd")
|
||||
inline void dump_fmha_fwd_json_results(const std::string& json_filename,
|
||||
const std::string& prec,
|
||||
const std::string& mode,
|
||||
const std::string& io_layout,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_qs,
|
||||
int seqlen_ks,
|
||||
int seqlen_kpads,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
float scale_s,
|
||||
float p_drop,
|
||||
bool lse,
|
||||
const std::string& qscale,
|
||||
const std::string& bias,
|
||||
const std::string& vlayout,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "fmha_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
@@ -658,33 +659,33 @@ void dump_fmha_fwd_json_results(const std::string& json_filename,
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_fmha_bwd_json_results(const std::string& json_filename,
|
||||
const std::string& data_type,
|
||||
const std::string& mode,
|
||||
const std::string& i_perm,
|
||||
const std::string& o_perm,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
float scale,
|
||||
const std::string& bias,
|
||||
bool use_dbias,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
bool deterministic,
|
||||
const std::string& mask,
|
||||
int mask_left,
|
||||
int mask_right,
|
||||
int workspace_size,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "fmha_bwd")
|
||||
inline void dump_fmha_bwd_json_results(const std::string& json_filename,
|
||||
const std::string& data_type,
|
||||
const std::string& mode,
|
||||
const std::string& i_perm,
|
||||
const std::string& o_perm,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
float scale,
|
||||
const std::string& bias,
|
||||
bool use_dbias,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
bool deterministic,
|
||||
const std::string& mask,
|
||||
int mask_left,
|
||||
int mask_right,
|
||||
int workspace_size,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "fmha_bwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
|
||||
Reference in New Issue
Block a user