Merge branch 'develop' into lwpck-4181

This commit is contained in:
khushbu
2025-12-15 15:28:16 -05:00
166 changed files with 8195 additions and 2289 deletions

View File

@@ -72,7 +72,12 @@ inline bool is_xdl_supported()
is_gfx12_supported() || is_gfx11_supported();
}
template <typename ADataType, typename BDataType, index_t MPerXDL, index_t NPerXDL>
template <typename ADataType,
typename BDataType,
index_t MPerXDL64,
index_t NPerXDL64,
index_t MPerXDL32 = MPerXDL64,
index_t NPerXDL32 = NPerXDL64>
inline bool is_xdl_wmma_supported()
{
if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
@@ -82,7 +87,7 @@ inline bool is_xdl_wmma_supported()
}
else if(is_gfx12_supported() || is_gfx11_supported())
{
if constexpr((MPerXDL != 16) || (NPerXDL != 16))
if constexpr((MPerXDL32 != 16) || (NPerXDL32 != 16))
{
return false;
}

View File

@@ -8,10 +8,16 @@
#include <sstream>
#include <regex>
#include <optional>
#include <memory>
#include "ck/stream_config.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#endif
#endif
#include "ck/utility/get_id.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
@@ -91,6 +97,57 @@ static constexpr auto GetNXdlPerWave2()
IsWave64>(); \
}
template <index_t BlockSize_,
index_t MPerBlock_,
index_t NPerBlock_,
index_t MPerXDL_,
index_t NPerXDL_,
index_t MXdlPerWave_,
index_t CShuffleMXdlPerWavePerShuffle_,
index_t CShuffleNXdlPerWavePerShuffle_,
bool IsWave64>
static constexpr auto GetWarpTileConfig()
{
constexpr auto MXdlPerWave64 = MXdlPerWave_;
constexpr auto MXdlPerWave32 = MXdlPerWave_ * MPerXDL_ / 16;
constexpr auto CShuffleMXdlPerWavePerShuffle32 = CShuffleMXdlPerWavePerShuffle_ * MPerXDL_ / 16;
constexpr auto NXdlPerWave =
IsWave64
? GetNXdlPerWave2<BlockSize_,
MPerBlock_,
NPerBlock_,
MPerXDL_,
NPerXDL_,
MXdlPerWave_,
true>()
: GetNXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
if constexpr(IsWave64 == false && NXdlPerWave != 0)
{
constexpr auto CShuffleNXdlPerWavePerShuffle32 =
NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
? CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
: CShuffleNXdlPerWavePerShuffle_;
static_assert(CShuffleNXdlPerWavePerShuffle32 > 0);
return Sequence<16,
16,
MXdlPerWave32,
NXdlPerWave,
CShuffleMXdlPerWavePerShuffle32,
CShuffleNXdlPerWavePerShuffle32>{};
}
else
{
return Sequence<MPerXDL_,
NPerXDL_,
MXdlPerWave64,
NXdlPerWave,
CShuffleMXdlPerWavePerShuffle_,
CShuffleNXdlPerWavePerShuffle_>{};
}
}
#define INVOKER_RUN_IMPL \
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
{ \
@@ -227,6 +284,12 @@ struct BaseOperator
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
virtual std::string GetTypeString() const { return ""; }
#ifdef CK_EXPERIMENTAL_BUILDER
// Return a description object for this operator, or nullptr if not supported.
virtual std::unique_ptr<ck_tile::reflect::Description> describe() const { return nullptr; }
#endif
virtual std::string GetInstanceString() const { return ""; }
virtual std::string GetTypeIdName() const { return typeid(*this).name(); }

View File

@@ -166,11 +166,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
{
using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle;
GET_NXDL_PER_WAVE_IMPL
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto WarpTileConfig64 = GetWarpTileConfig<BlockSize,
MPerBlock,
NPerBlock,
MPerXDL,
NPerXDL,
MXdlPerWave,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
true>();
static constexpr auto WarpTileConfig32 = GetWarpTileConfig<BlockSize,
MPerBlock,
NPerBlock,
MPerXDL,
NPerXDL,
MXdlPerWave,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
false>();
static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3);
static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3);
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -321,7 +337,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
// GridwiseGemm
template <index_t NXdlPerWave_>
template <typename WarpTileConfig>
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
BDataType,
@@ -340,10 +356,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave_,
WarpTileConfig::At(0),
WarpTileConfig::At(1),
WarpTileConfig::At(2),
WarpTileConfig::At(3),
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
@@ -360,13 +376,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
WarpTileConfig::At(4),
WarpTileConfig::At(5),
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
using GridwiseGemm64 = GridwiseGemmBase<decltype(WarpTileConfig64)>;
using GridwiseGemm32 = GridwiseGemmBase<decltype(WarpTileConfig32)>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 =
@@ -588,7 +604,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_wmma_supported<ComputeDataType, ComputeDataType, MPerXDL, NPerXDL>())
if(!ck::is_xdl_wmma_supported<ComputeDataType,
ComputeDataType,
MPerXDL,
NPerXDL,
WarpTileConfig32.At(0),
WarpTileConfig32.At(1)>())
{
return false;
}
@@ -783,6 +804,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< AK1 << ", "
<< BK1 << ", "
<< ABlockTransferSrcVectorDim << ", "

View File

@@ -25,6 +25,7 @@
#include "ck/host_utility/kernel_launch.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
#endif
@@ -1240,6 +1241,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override

View File

@@ -25,6 +25,7 @@
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#endif
@@ -1064,6 +1065,11 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
};

View File

@@ -29,6 +29,7 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/conv_describe.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#endif
@@ -2080,6 +2081,21 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
static_assert(ck_tile::reflect::conv::HasConvTraits<DeviceOp>,
"ConvTraits specialization not found for this device operation. "
"If you modified the template parameters of this class, ensure that "
"the corresponding ConvTraits specialization in "
"ck_tile/builder/reflect/conv_traits.hpp is updated to match, or that "
"InstanceTraits in "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp "
"provides all required members for ConvTraits to work.");
return std::make_unique<ck_tile::reflect::conv::ConvDescription>(
ck_tile::reflect::describe<DeviceOp>());
}
#endif
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override

View File

@@ -29,6 +29,7 @@
#include "ck/host_utility/flush_cache.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/conv_describe.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#endif
@@ -2103,6 +2104,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::conv::ConvDescription>(
ck_tile::reflect::describe<DeviceOp>());
}
#endif
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override

View File

@@ -25,6 +25,7 @@
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
#endif
@@ -1019,6 +1020,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
return std::make_unique<ck_tile::reflect::InstanceStringDescription>(GetInstanceString());
}
#endif
};

View File

@@ -25,6 +25,7 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/conv_describe.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#endif
@@ -1238,6 +1239,22 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
"for the given template parameters.");
return ck_tile::reflect::instance_string<DeviceOp>();
}
std::unique_ptr<ck_tile::reflect::Description> describe() const override
{
static_assert(
ck_tile::reflect::conv::HasConvTraits<DeviceOp>,
"ConvTraits specialization not found for this device operation. "
"If you modified the template parameters of this class, ensure that "
"the corresponding ConvTraits specialization in "
"ck_tile/builder/reflect/conv_traits.hpp is updated to match, or that "
"InstanceTraits in "
"ck_tile/builder/reflect/"
"instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp "
"provides all required members for ConvTraits to work.");
return std::make_unique<ck_tile::reflect::conv::ConvDescription>(
ck_tile::reflect::describe<DeviceOp>());
}
#endif
};

View File

@@ -620,7 +620,44 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
bool isWave64 = get_warp_size() == 64;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& a = arg.gemm_kernel_args_[i].karg_;
const auto& a = arg.gemm_kernel_args_[i].karg_;
// Validate stride requirements for SplitK (k_batch > 1)
// TODO: Enable splitK
if(a.k_batch > 1)
{
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
if(a.StrideC != a.N)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " SplitK (k_batch=" << a.k_batch
<< ") requires contiguous output stride."
<< " For RowMajor layout: StrideC must equal N."
<< " Got StrideC=" << a.StrideC << ", N=" << a.N << std::endl;
}
return false;
}
}
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
if(a.StrideC != a.M)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " SplitK (k_batch=" << a.k_batch
<< ") requires contiguous output stride."
<< " For ColumnMajor layout: StrideC must equal M."
<< " Got StrideC=" << a.StrideC << ", M=" << a.M << std::endl;
}
return false;
}
}
}
bool group_arg_valid = false;
if(isWave64)
{

View File

@@ -527,11 +527,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
}
else
{
#if defined(__gfx11__)
// TODO: remove this restriction
static_assert(ScaleBlockM >= MPerWmma,
"ScaleBlockM must be greater equal than MPerWmma");
#endif
static_assert(
ScaleBlockK >=
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::

View File

@@ -366,6 +366,26 @@ struct amdgcn_compiler_target_state
#else
static constexpr bool CK_TILE_ARCH_GFX1010 = false;
#endif
#if defined(__gfx1011__)
static constexpr bool CK_TILE_ARCH_GFX1011 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1011 = false;
#endif
#if defined(__gfx1012__)
static constexpr bool CK_TILE_ARCH_GFX1012 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1012 = false;
#endif
#if defined(__gfx1013__)
static constexpr bool CK_TILE_ARCH_GFX1013 = true;
#else
static constexpr bool CK_TILE_ARCH_GFX1013 = false;
#endif
#if defined(__gfx10_1_generic__)
static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = true;
#else
static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = false;
#endif // __gfx10_1_generic__
#if defined(__gfx1030__)
static constexpr bool CK_TILE_ARCH_GFX1030 = true;
@@ -504,6 +524,10 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1011, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1012, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1013, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX10_1_GENERIC, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \
amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \

View File

@@ -20,7 +20,7 @@ CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, cons
{
for(int m = 0; m < M; ++m)
{
if(mask.IsOutOfBound(m, n))
if(mask.IsOutOfSinkBound(m, n))
c_b_m_n(batch, m, n) = -ck_tile::numeric<CDataType>::infinity();
}
}

View File

@@ -34,77 +34,80 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
AccDataType v_acc = 0, v_block_acc = 0;
AccDataType v_acc = 0;
static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
std::is_same_v<ADataType, bf8_t>);
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
std::is_same_v<BDataType, pk_int4_t>);
static_assert(std::is_same_v<AccDataType, float>);
static_assert(std::is_same_v<CDataType, float> ||
std::is_same_v<CDataType, ck_tile::half_t>);
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_a;
AccDataType v_b;
constexpr std::size_t kGroupK = QuantGroupSize::kK;
// ---- A loader: dequant A(m,k) into AccDataType ----
auto load_a = [&](std::size_t k) -> AccDataType {
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
return (k & 1) ? fp32_val.hi : fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
return ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
}
};
// ---- B loader: dequant B(k,n) into AccDataType ----
auto load_b = [&](std::size_t k) -> AccDataType {
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
return (k & 1) ? fp32_val.hi : fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, fp8_t>)
{
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
return fp8_to_float_raw(b_element_op(b_k_n(k, n)));
}
else
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
return ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
}
v_block_acc += v_a * v_b;
};
// Apply group dequant scale
if((k + 1) % QuantGroupSize::kK == 0)
// ---- scale loader for a given K-group index ----
auto load_scale = [&](ck_tile::index_t k_group) -> float {
const ck_tile::index_t outer_dim = aquant ? (m / QuantGroupSize::kM) : k_group;
const ck_tile::index_t inner_dim = aquant ? k_group : (n / QuantGroupSize::kN);
if constexpr(std::is_same_v<QDataType, float>)
{
float scale = 0.f;
index_t outer_dim = (aquant) ? (m / QuantGroupSize::kM) : (k / QuantGroupSize::kK);
index_t inner_dim = (aquant) ? (k / QuantGroupSize::kK) : (n / QuantGroupSize::kN);
if constexpr(std::is_same_v<QDataType, float>)
{
scale = q(outer_dim, inner_dim);
}
else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
{
scale = fp8_to_float_raw(q(outer_dim, inner_dim));
}
else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
{
scale = bf8_to_float_raw(q(outer_dim, inner_dim));
}
else
{
static_assert(false, "Unexpected Q datatype.");
}
v_block_acc *= scale;
v_acc += v_block_acc;
v_block_acc = 0;
return q(outer_dim, inner_dim);
}
else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
{
return fp8_to_float_raw(q(outer_dim, inner_dim));
}
else // QDataType == bf8_t by static_assert above
{
return bf8_to_float_raw(q(outer_dim, inner_dim));
}
};
// ---- Loop over K by groups (full and tail) ----
for(std::size_t k_begin = 0; k_begin < K; k_begin += kGroupK)
{
const std::size_t k_end = std::min<std::size_t>(k_begin + kGroupK, K);
AccDataType v_block_acc = 0;
// unscaled accumulation within this K-group
for(std::size_t k = k_begin; k < k_end; ++k)
{
const AccDataType v_a = load_a(k);
const AccDataType v_b = load_b(k);
v_block_acc += v_a * v_b;
}
const ck_tile::index_t k_group = static_cast<ck_tile::index_t>(k_begin / kGroupK);
const float scale = load_scale(k_group);
v_acc += v_block_acc * scale;
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));

View File

@@ -84,7 +84,7 @@ auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
@@ -94,10 +94,10 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
@@ -114,18 +114,24 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
gemmConfig.K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
return shuffle_b(t, GemmConfig{});
}
template <typename GemmConfig, typename T>
auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
{
@@ -145,22 +151,22 @@ auto bq_permuteN(const ck_tile::HostTensor<T>& t, index_t group_n)
}
template <typename GemmConfig, typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
@@ -177,17 +183,23 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
else
{
assert(is_wave32() == false);
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
gemmConfig.K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}
}
template <typename GemmConfig, typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
{
return shuffle_b_permuteN(t, GemmConfig{});
}
} // namespace ck_tile

View File

@@ -8,7 +8,8 @@
namespace ck_tile {
enum StreamKReductionStrategy : uint32_t
{
Atomic = 0u,
Reduction = 1u
Atomic = 0u,
Reduction = 1u,
TreeReduction = 2u
};
} // namespace ck_tile

View File

@@ -35,7 +35,8 @@ template <typename AsDataType_,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false,
index_t BlockedXDLN_PerWarp_ = 1> // The number of continuous xdl_output per warp
index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp
bool DoubleSmemBuffer_ = false>
struct CShuffleEpilogueProblem
{
using AsDataType = remove_cvref_t<AsDataType_>;
@@ -59,6 +60,7 @@ struct CShuffleEpilogueProblem
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
@@ -118,6 +120,7 @@ struct CShuffleEpilogue
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
@@ -204,6 +207,26 @@ struct CShuffleEpilogue
}
return max_vector_size / sizeof(DiDataType);
}
/**
* @brief Shuffle tile configuration parameters check and aligment
*
* @details Return tuple(1, 1) if shuffle_tile values are too large for SMEM.
*/
template <index_t m_shuffle_tile, index_t n_shuffle_tile>
CK_TILE_HOST_DEVICE static constexpr auto AlignShuffleTileWithSmem()
{
constexpr index_t m_val = MPerXdl * MWave * m_shuffle_tile;
constexpr index_t n_val = NPerXdl * NWave * n_shuffle_tile;
constexpr auto shuffle_tile =
m_val * n_val * sizeof(ODataType) > get_smem_capacity() || DoubleSmemBuffer
? std::make_tuple(1, 1)
: std::make_tuple(m_shuffle_tile, n_shuffle_tile);
return shuffle_tile;
}
/**
* @brief Shuffle tile configuration parameters
*
@@ -214,20 +237,23 @@ struct CShuffleEpilogue
*/
static constexpr auto shuffle_tile_tuple = [] {
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
if constexpr(elem_per_thread >= GetVectorSizeC())
if constexpr(elem_per_thread <= GetVectorSizeC())
{
return std::make_tuple(1, 1);
}
else
{
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
constexpr index_t num_xdl_shuffles = elem_per_thread / GetVectorSizeC();
static_assert(elem_per_thread % GetVectorSizeC() == 0);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
(kMPerBlock % num_xdl_shuffles == 0),
"kMPerBlock must be divisible by MPerXdl*MWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
return AlignShuffleTileWithSmem<min(num_xdl_shuffles,
kMPerBlock / (MPerXdl * MWave)),
1>();
}
else
{
@@ -235,7 +261,9 @@ struct CShuffleEpilogue
(kNPerBlock % num_xdl_shuffles == 0),
"kNPerBlock must be divisible by NPerXdl*NWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
return AlignShuffleTileWithSmem<1,
min(num_xdl_shuffles,
kNPerBlock / (NPerXdl * NWave))>();
}
}
}();

View File

@@ -86,21 +86,22 @@ struct GenericAttentionMask
static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_)
: GenericAttentionMask(0, 0, y_total_, x_total_)
: GenericAttentionMask(0, 0, 0, y_total_, x_total_)
{
}
CK_TILE_HOST_DEVICE
GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
GenericAttentionMask(index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
{
}
template <typename MaskCoordinates>
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord)
: y(mask_coord.at(number<0>{})),
x(mask_coord.at(number<1>{})),
y_total(mask_coord.at(number<2>{})),
x_total(mask_coord.at(number<3>{}))
sink(mask_coord.at(number<2>{})),
y_total(mask_coord.at(number<3>{})),
x_total(mask_coord.at(number<4>{}))
{
}
@@ -141,6 +142,44 @@ struct GenericAttentionMask
}
}
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, 0, x_total);
}
else
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
if constexpr(IsLocal)
{
index_t tmp = max(-y + i_y + 1, 0);
return (tmp / XTile) * XTile; // round to tile aligned
}
else
{
return 0;
}
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min(i_y + YTile - 1 + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
if(x_start <= sink_seq_end && sink > 0)
return ck_tile::make_tuple(0, 0, x_end);
else
return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
}
}
// to get the loop length along Y axis, return index:[start, end), end-start=length
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
// TODO: y_end still could be negative, so end-start could be negative(need check)
@@ -195,6 +234,30 @@ struct GenericAttentionMask
}
}
CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
{
if constexpr(!IsMasking)
return i_x >= x_total;
// no need to do min/max here, since i_x will never be < 0 or >= x_total
index_t x_start = -y + i_y + 1;
index_t x_end = min(i_y + x, x_total);
if constexpr(IsLocal)
{
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
return false;
else
return i_x < x_start || i_x >= x_end;
}
else
{
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
return false;
else
return i_x >= x_end || i_y >= y_total;
}
}
// if current tile is at the edge, means need per-pixel mask check.
// otherwise no need to check per-pixel
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
@@ -237,7 +300,7 @@ struct GenericAttentionMask
}
private:
index_t y, x;
index_t y, x, sink;
index_t y_total, x_total;
};
@@ -260,21 +323,23 @@ struct SimplifiedGenericAttentionMask
static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
: SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
: SimplifiedGenericAttentionMask(0, 0, 0, y_total_, x_total_)
{
}
CK_TILE_HOST_DEVICE
SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
SimplifiedGenericAttentionMask(
index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_)
{
}
template <typename MaskCoordinates>
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
: y(mask_coord.at(number<0>{})),
x(mask_coord.at(number<1>{})),
y_total(mask_coord.at(number<2>{})),
x_total(mask_coord.at(number<3>{}))
sink(mask_coord.at(number<2>{})),
y_total(mask_coord.at(number<3>{})),
x_total(mask_coord.at(number<4>{}))
{
}
@@ -308,6 +373,38 @@ struct SimplifiedGenericAttentionMask
}
}
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetSinkTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, 0, x_total);
}
else
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
index_t tmp = max(-y + i_y + 1, 0);
return (tmp / XTile) * XTile; // round to tile aligned
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min(i_y + YTile - 1 + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0;
if(x_start <= sink_seq_end && sink > 0)
return ck_tile::make_tuple(0, 0, x_end);
else
return ck_tile::make_tuple(sink_seq_end, x_start, x_end);
}
}
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y,
number<TileHeight> height,
@@ -325,6 +422,29 @@ struct SimplifiedGenericAttentionMask
ck_tile::min(origin_end, split_end));
}
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto GetSinkTileRangeAlongX(index_t i_y,
number<TileHeight> height,
number<TileWidth> width,
index_t num_splits,
index_t i_split) const
{
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
const index_t split_start = x_per_split * i_split; // 128
const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); // 256
const index_t sink_seq_end = sink > 0 ? ((sink + width - 1) / width) * width : 0;
const index_t start = ck_tile::max(origin_start, split_start);
const index_t end = ck_tile::min(origin_end, split_end);
const bool is_first_intersecting_split =
(split_start <= origin_start && split_end >= origin_start);
const bool sink_in_range = (sink_seq_end <= start);
const index_t sink_offset =
(is_first_intersecting_split && sink_in_range) ? sink_seq_end : 0;
return ck_tile::make_tuple(sink_offset, start, end);
}
// to get the loop length along Y axis, return index:[start, end), end-start=length
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
// TODO: y_end still could be negative, so end-start could be negative(need check)
@@ -368,11 +488,22 @@ struct SimplifiedGenericAttentionMask
{
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
return i_x < x_start || i_x >= x_end || i_y >= y_total;
}
}
CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const
{
if constexpr(!IsMasking)
return i_x >= x_total;
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total)
return false;
else
return i_x < x_start || i_x >= x_end || i_y >= y_total;
}
// if current tile is at the edge, means need per-pixel mask check.
// otherwise no need to check per-pixel
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
@@ -406,7 +537,7 @@ struct SimplifiedGenericAttentionMask
}
private:
index_t y, x;
index_t y, x, sink;
index_t y_total, x_total;
};
@@ -620,6 +751,7 @@ static constexpr bool is_generic_attention_mask_v = is_generic_attention_mask<Ma
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
index_t right_size,
index_t sink_size,
index_t y_total,
index_t x_total,
bool is_top_left = true)
@@ -637,7 +769,21 @@ make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
index_t x = 1 + right_size + x_tmp;
index_t y = 1 + left_size + y_tmp;
return ck_tile::make_tuple(y, x, y_total, x_total);
return ck_tile::make_tuple(y, x, sink_size, y_total, x_total);
}
template <typename MaskType>
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_from_lr_window(index_t left_size,
index_t right_size,
index_t sink_size,
index_t y_total,
index_t x_total,
bool is_top_left = true)
{
auto r = make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, sink_size, y_total, x_total, is_top_left);
return MaskType{r.at(number<0>{}), r.at(number<1>{}), sink_size, y_total, x_total};
}
template <typename MaskType>
@@ -649,7 +795,7 @@ make_generic_attention_mask_from_lr_window(index_t left_size,
bool is_top_left = true)
{
auto r = make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, is_top_left);
return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
left_size, right_size, 0, y_total, x_total, is_top_left);
return MaskType{r.at(number<0>{}), r.at(number<1>{}), 0, y_total, x_total};
}
} // namespace ck_tile

View File

@@ -162,6 +162,17 @@ struct StandardAttention
{
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
}
template <typename Params>
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
[[maybe_unused]] uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
}
};
template <bool UseExp2 = false>
@@ -224,6 +235,17 @@ struct LogitsSoftCap
{
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
}
template <typename Params>
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
[[maybe_unused]] uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
}
};
constexpr uint32_t CUSTOM_MASK = 1U;
@@ -297,6 +319,17 @@ struct ComposedAttention
{
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
}
template <typename Params>
__device__ __forceinline__ bool LogitsSinkMask(const Params& params,
[[maybe_unused]] uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx);
}
};
} // namespace ck_tile

View File

@@ -200,7 +200,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -356,6 +356,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -418,6 +419,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -497,6 +499,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
ck_tile::index_t batch_stride_v,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -557,6 +560,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -1008,6 +1012,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);

View File

@@ -58,6 +58,7 @@ struct FmhaFwdKernel
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
static constexpr bool kHasSink = FmhaPipeline::kHasSink;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
@@ -155,7 +156,7 @@ struct FmhaFwdKernel
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -335,6 +336,7 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -393,6 +395,7 @@ struct FmhaFwdKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -481,6 +484,7 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -529,6 +533,7 @@ struct FmhaFwdKernel
batch_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
p_drop,
s_randval,
@@ -580,6 +585,7 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
@@ -628,6 +634,7 @@ struct FmhaFwdKernel
batch_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
p_drop,
s_randval,
@@ -673,6 +680,7 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
@@ -732,6 +740,7 @@ struct FmhaFwdKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -817,6 +826,7 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
@@ -861,6 +871,7 @@ struct FmhaFwdKernel
nhead_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
min_seqlen_q,
p_drop,
@@ -908,6 +919,7 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
@@ -952,6 +964,7 @@ struct FmhaFwdKernel
nhead_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type,
min_seqlen_q,
p_drop,
@@ -1443,6 +1456,7 @@ struct FmhaFwdKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
@@ -2182,6 +2196,7 @@ struct FmhaFwdKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);

View File

@@ -55,6 +55,7 @@ struct FmhaFwdPagedKVKernel
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
static constexpr bool kHasSink = FmhaPipeline::kHasSink;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
@@ -101,7 +102,7 @@ struct FmhaFwdPagedKVKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
#undef _SS_
#undef _TS_
// clang-format on
@@ -189,7 +190,7 @@ struct FmhaFwdPagedKVKernel
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -326,6 +327,7 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
@@ -379,6 +381,7 @@ struct FmhaFwdPagedKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -453,6 +456,7 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
return MakeKargsImpl(q_ptr,
@@ -495,6 +499,7 @@ struct FmhaFwdPagedKVKernel
batch_stride_o,
window_size_left,
window_size_right,
sink_size,
mask_type);
}
@@ -536,6 +541,7 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q)
{
@@ -590,6 +596,7 @@ struct FmhaFwdPagedKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
@@ -660,6 +667,7 @@ struct FmhaFwdPagedKVKernel
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q)
{
@@ -699,6 +707,7 @@ struct FmhaFwdPagedKVKernel
batch_stride_v,
window_size_left,
window_size_right,
sink_size,
mask_type,
min_seqlen_q);
}
@@ -1257,6 +1266,7 @@ struct FmhaFwdPagedKVKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);

View File

@@ -51,6 +51,7 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
static constexpr bool kHasSink = FmhaPipeline::Problem::kHasSink;
static constexpr bool kMergeNumHeadGroupsSeqLenQ =
FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
@@ -101,7 +102,7 @@ struct FmhaFwdSplitKVKernel
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
(kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
(kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" );
#undef _SS_
#undef _TS_
// clang-format on
@@ -198,7 +199,7 @@ struct FmhaFwdSplitKVKernel
struct MaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right, sink_size;
ck_tile::GenericAttentionMaskEnum mask_type;
};
@@ -325,6 +326,7 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
@@ -384,6 +386,7 @@ struct FmhaFwdSplitKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kDoFp8StaticQuant)
@@ -451,6 +454,7 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t sink_size,
ck_tile::index_t mask_type)
{
Kargs kargs{{q_ptr,
@@ -508,6 +512,7 @@ struct FmhaFwdSplitKVKernel
{
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.sink_size = sink_size;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kDoFp8StaticQuant)
@@ -994,6 +999,7 @@ struct FmhaFwdSplitKVKernel
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.sink_size,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);

View File

@@ -57,6 +57,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasSink = Problem::kHasSink;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -228,10 +229,22 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
else
{
auto [start, end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
@@ -255,7 +268,6 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
return o_acc;
}
}
// k_dram_block_window
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
@@ -274,27 +286,36 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
return physical_seqlen_k_start_;
}
}();
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
? aligned_physical_seqlen_k_start
: 0;
const index_t num_total_loop =
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
num_sink_loop;
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
k_dram_block_window_lengths, {kv_load_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
const index_t bias_n_offset = [&]() {
if constexpr(kHasSink)
return kv_load_start;
else
return logical_seqlen_k_start -
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
}();
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}),
logical_seqlen_k_start - (physical_seqlen_k_start -
aligned_physical_seqlen_k_start)}, // M/N
{bias_origin.at(number<0>{}), bias_n_offset},
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
// v_dram_window
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
{0, kv_load_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q);
// prefetch K tile
@@ -321,9 +342,16 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
k_block_tile = load_tile(k_dram_window);
}
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
const auto k_move_offset = [&]() {
if constexpr(kHasSink)
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
else
return kN0;
}();
auto physical_next_block_id_k =
amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
i_page_block_k, k_dram_block_window, {kN0, 0}));
i_page_block_k, k_dram_block_window, {k_move_offset, 0}));
auto physical_next_block_id_v = amd_wave_read_first_lane(
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
@@ -442,7 +470,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
move_tile_window(bias_dram_window, {0, k_move_offset});
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
@@ -474,14 +502,29 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col - kv_l2p_offset);
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask_func(row, col - kv_l2p_offset);
});
};
if constexpr(kHasSink)
{
apply_mask([&](auto row, auto col) {
return mask.IsOutOfSinkBound(row, col);
});
}
else
{
apply_mask(
[&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
}
}
}
}
@@ -647,7 +690,12 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
}
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k);
physical_next_block_id_v =
amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id(
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}));
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v);
// tail
{
block_sync_lds();

View File

@@ -57,6 +57,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
static constexpr bool kHasSink = Problem::kHasSink;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -256,11 +257,23 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
else
{
auto [start, end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
// check early exit if no work to do
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{
const index_t logical_num_total_loop =
@@ -304,24 +317,33 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
return physical_seqlen_k_start_;
}
}();
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
? aligned_physical_seqlen_k_start
: 0;
const index_t num_total_loop =
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
num_sink_loop;
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
k_dram_block_window_lengths, {kv_load_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
const index_t bias_n_offset = [&]() {
if constexpr(kHasSink)
return kv_load_start;
else
return logical_seqlen_k_start -
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
}();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}),
logical_seqlen_k_start - (physical_seqlen_k_start -
aligned_physical_seqlen_k_start)}, // M/N
{bias_origin.at(number<0>{}), bias_n_offset},
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
{0, kv_load_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// store Q into LDS
@@ -369,7 +391,13 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
{
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
const auto k_move_offset = [&]() {
if constexpr(kHasSink)
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
else
return kN0;
}();
// load the second tile of the first iteration
k_block_tile = load_tile(k_dram_window);
@@ -494,7 +522,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
move_tile_window(bias_dram_window, {0, k_move_offset});
/// TODO: only check in first/last iteration without increasing code size
if constexpr(kHasUnevenSplits)
@@ -505,7 +533,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&,
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
@@ -530,12 +558,26 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col - kv_l2p_offset);
});
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask_func(row, col - kv_l2p_offset);
});
};
if constexpr(kHasSink)
{
apply_mask(
[&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); });
}
else
{
apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
}
}
}
@@ -546,7 +588,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
{
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
i_page_block_k, k_dram_block_window, {k_move_offset, 0});
k_dram_window = make_tile_window(
k_dram_block_window,
@@ -742,6 +784,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window(k_dram_window, {0, kK0});
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_window, {0, k_move_offset - kN0});
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
}
} while(++i_total_loops < num_total_loop);

View File

@@ -56,6 +56,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
static constexpr bool kHasSink = Problem::kHasSink;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -229,9 +230,23 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
else
{
auto [start, end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
@@ -274,24 +289,35 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return physical_seqlen_k_start_;
}
}();
const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0)
? aligned_physical_seqlen_k_start
: 0;
const index_t num_total_loop =
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) +
num_sink_loop;
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
k_dram_block_window_lengths, {kv_load_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
const index_t bias_n_offset = [&]() {
if constexpr(kHasSink)
return kv_load_start;
else
return logical_seqlen_k_start -
(physical_seqlen_k_start - aligned_physical_seqlen_k_start);
}();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}),
logical_seqlen_k_start - (physical_seqlen_k_start -
aligned_physical_seqlen_k_start)}, // M/N
{bias_origin.at(number<0>{}), bias_n_offset},
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
{0, kv_load_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q);
@@ -320,9 +346,18 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
k_block_tile = load_tile(k_dram_window);
}
const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops);
const auto k_move_offset = [&]() {
if constexpr(kHasSink)
return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0;
else
return kN0;
}();
auto physical_next_block_id_k =
amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
i_page_block_k, k_dram_block_window, {kN0, 0}));
i_page_block_k, k_dram_block_window, {k_move_offset, 0}));
auto physical_next_block_id_v = amd_wave_read_first_lane(
v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
@@ -441,7 +476,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
#endif
}
}
move_tile_window(bias_dram_window, {0, kN0});
move_tile_window(bias_dram_window, {0, k_move_offset});
/// TODO: only check in first/last iteration without increasing code size
if constexpr(kHasUnevenSplits)
@@ -452,7 +487,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&,
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
@@ -477,12 +512,26 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col - kv_l2p_offset);
});
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask_func(row, col - kv_l2p_offset);
});
};
if constexpr(kHasSink)
{
apply_mask(
[&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); });
}
else
{
apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); });
}
}
}
@@ -647,7 +696,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k);
physical_next_block_id_v =
amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id(
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}));
i_page_block_v = v_page_block_navigator.move_tile_window(
i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v);
// tail
{
block_sync_lds();

View File

@@ -62,6 +62,7 @@ struct BlockFmhaPipelineProblem
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr auto QScaleEnum = Traits::QScaleEnum;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kHasSink = Traits::kHasSink;
};
template <typename QDataType_,
@@ -114,6 +115,7 @@ struct BlockFmhaFwdPagedKVPipelineProblem
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kHasSink = Traits::kHasSink;
};
template <typename QDataType_,
@@ -167,6 +169,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kHasSink = Traits::kHasSink;
};
// extract tile size attributes to remove dependency on traits

View File

@@ -57,6 +57,7 @@ struct BlockFmhaPipelineQRKSVS
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasSink = Problem::kHasSink;
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
@@ -233,10 +234,26 @@ struct BlockFmhaPipelineQRKSVS
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
const auto tile_range_result = [&mask, &q_origin]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
else
{
auto [start, end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
const auto num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
@@ -262,22 +279,22 @@ struct BlockFmhaPipelineQRKSVS
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
{kv_load_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
randval_dram_block_window_tmp, kv_load_start);
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
{0, kv_load_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q);
@@ -450,6 +467,11 @@ struct BlockFmhaPipelineQRKSVS
#endif
}
}
if constexpr(kHasSink)
{
if(i_total_loops == 0)
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
@@ -460,17 +482,34 @@ struct BlockFmhaPipelineQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !mask_func(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
};
if constexpr(kHasSink)
{
apply_mask([&](auto&&... args) {
return variant.LogitsSinkMask(std::forward<decltype(args)>(args)...);
});
}
else
{
apply_mask([&](auto&&... args) {
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
});
}
}
}
@@ -580,11 +619,23 @@ struct BlockFmhaPipelineQRKSVS
if constexpr(kHasDropout)
{
// K and dropout use the same address in LDS, finish loading from k_lds_window by
// gemm_0 to reuse LDS.
block_sync_lds();
auto randval_ptr = reinterpret_cast<char*>(smem_ptr);
index_t seq_offset = [&]() {
if constexpr(!kHasSink)
return seqlen_k_start + i_total_loops * kN0;
const bool in_sink_phase = (num_sink_loop > i_total_loops);
if(i_total_loops == num_sink_loop)
move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end});
return in_sink_phase ? (kv_load_start + i_total_loops * kN0)
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
}();
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
randval_ptr, seq_offset, p_compute, randval_dram_window);
}
block_sync_lds();
@@ -636,6 +687,14 @@ struct BlockFmhaPipelineQRKSVS
});
}
// move K tile windows
if constexpr(kHasSink)
{
if(i_total_loops == 0)
{
move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0});
move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end});
}
}
move_tile_window(k_dram_block_window, {kN0, 0});
// tail
{

View File

@@ -62,6 +62,7 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasSink = Problem::kHasSink;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
@@ -277,11 +278,26 @@ struct BlockFmhaPipelineQRKSVSAsync
clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin]() {
if constexpr(kHasSink)
return mask.GetSinkTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
else
{
auto [start, end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
return ck_tile::make_tuple(0, start, end);
}
}();
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
const auto num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
@@ -309,7 +325,7 @@ struct BlockFmhaPipelineQRKSVSAsync
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
{kv_load_start, 0});
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
@@ -332,16 +348,16 @@ struct BlockFmhaPipelineQRKSVSAsync
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
randval_dram_block_window_tmp, kv_load_start);
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
{0, kv_load_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile
@@ -478,6 +494,11 @@ struct BlockFmhaPipelineQRKSVSAsync
#endif
}
}
if constexpr(kHasSink)
{
if(i_total_loops == 0)
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
}
move_tile_window(bias_dram_window, {0, kN0});
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
@@ -489,17 +510,34 @@ struct BlockFmhaPipelineQRKSVSAsync
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !variant.LogitsMask(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !mask_func(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
};
if constexpr(kHasSink)
{
apply_mask([&](auto&&... args) {
return variant.LogitsSinkMask(std::forward<decltype(args)>(args)...);
});
}
else
{
apply_mask([&](auto&&... args) {
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
});
}
}
}
@@ -647,11 +685,21 @@ struct BlockFmhaPipelineQRKSVSAsync
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
index_t seq_offset = [&]() {
if constexpr(!kHasSink)
return seqlen_k_start + i_total_loops * kN0;
const bool in_sink_phase = (num_sink_loop > i_total_loops);
if(i_total_loops == num_sink_loop)
move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end});
return in_sink_phase ? (kv_load_start + i_total_loops * kN0)
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
}();
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
randval_ptr, seq_offset, p_compute, randval_dram_window);
}
const auto p = [&]() {
@@ -717,8 +765,16 @@ struct BlockFmhaPipelineQRKSVSAsync
i_total_loops++;
if(i_total_loops < num_total_loop)
{
// move K tile windows
if constexpr(kHasSink)
{
if(i_total_loops == 0)
{
move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0});
move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end});
}
}
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
if constexpr(k1_loops >= 2 &&

View File

@@ -69,6 +69,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasUnevenSplits = true;
static constexpr bool kHasSink = Problem::kHasSink;
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||

View File

@@ -20,8 +20,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kStoreLSE_,
bool kHasDropout_,
BlockAttentionQuantScaleEnum QScaleEnum_,
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
bool kHasSink_ = false>
struct TileFmhaTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
@@ -36,6 +37,7 @@ struct TileFmhaTraits
static constexpr auto QScaleEnum = QScaleEnum_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
static constexpr bool kHasSink = kHasSink_;
};
template <index_t kPadHeadDimQ_ /* paddding for hdim_q */,
@@ -65,8 +67,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
bool kIsPagedKV_,
bool kDoFp8StaticQuant_,
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
bool kHasSink_ = false>
struct TileFmhaFwdPagedKVTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
@@ -81,6 +84,7 @@ struct TileFmhaFwdPagedKVTraits
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
static constexpr bool kHasSink = kHasSink_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
@@ -95,7 +99,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kIsPagedKV_,
bool kHasUnevenSplits_,
bool kMergeNumHeadGroupsSeqLenQ_ = false,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kHasSink_ = false>
struct TileFmhaFwdSplitKVTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
@@ -112,6 +117,7 @@ struct TileFmhaFwdSplitKVTraits
static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr bool kHasSink = kHasSink_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,

View File

@@ -986,6 +986,8 @@ struct MoeSortingKernel
p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
}
}
__syncthreads();
smem_cumdup(num_experts) = smem_cumsum(num_experts);
// fill the p_sorted_token_ids/p_sorted_weights

View File

@@ -33,9 +33,10 @@
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"

View File

@@ -232,7 +232,7 @@ struct BatchedGemmKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr1[GetSmemSize()];
__shared__ char smem_ptr1[GemmPipeline::GetSmemSize()];
UniversalGemmKernel::RunGemm2LDS({a_ptr},
{b_ptr},
{/*ds_ptr*/},

View File

@@ -310,7 +310,7 @@ struct GroupedGemmKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
c_ptr,
@@ -561,6 +561,7 @@ struct GroupedGemmKernel
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
block_sync_lds();
block_id = block_id + grid_size; // advance to next block
// NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
if(block_id >= cum_grid_size)

View File

@@ -0,0 +1,35 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
namespace ck_tile {
template <typename CompilerTarget, typename Enabler = void>
struct StreamKCoherency
{
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
amd_buffer_coherence_enum::coherence_default;
};
template <typename CompilerTarget>
struct StreamKCoherency<CompilerTarget,
core::arch::enable_if_target_id_t<CompilerTarget,
core::arch::amdgcn_target_id::GFX942,
core::arch::amdgcn_target_id::GFX950>>
{
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
amd_buffer_coherence_enum::SYSTEM_NT0;
};
template <typename CompilerTarget>
struct StreamKCoherency<CompilerTarget,
core::arch::enable_if_target_id_t<CompilerTarget,
core::arch::amdgcn_target_id::GFX908,
core::arch::amdgcn_target_id::GFX90A>>
{
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
amd_buffer_coherence_enum::glc_slc;
};
} // namespace ck_tile

View File

@@ -6,6 +6,7 @@
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "streamk_gemm_coherency.hpp"
namespace ck_tile {
@@ -318,37 +319,58 @@ struct StreamKKernel
* results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the current thread block (CTA).
* @note This function utilizes a workgroup barrier to set a synchronization flag for the given
* CTA index.
* @note This function utilizes a scalar store to write to the flags buffer.
*/
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
index_t cta_idx) const
{
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
workgroup_barrier sk_flags(sk_flags_ptr);
sk_flags.wait_set(0, 1, cta_idx);
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t offset = cta_idx * sizeof(index_t);
asm volatile("s_mov_b32 m0, %2\n\t"
// Depending on the architecture, the GLC flag will bypass the approproriate
// cache level(s) to ensure the write is visible to other workgroups. See the
// appropriate ISA for details about the GLC modifier.
"s_store_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the store to complete
:
: "s"(1), "s"(sk_flags_ptr), "s"(offset)
: "memory");
}
/**
* @brief Waits for the thread block (cta_idx) to complete storing its partial results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @note This function utilizes a workgroup barrier to wait for the synchronization flag to be
* set by the given CTA index.
* @note This function utilizes a scalar load to read from the flags
* buffer.
*/
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
{
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
workgroup_barrier sk_flags(sk_flags_ptr);
sk_flags.wait_eq(1, cta_idx);
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t result;
index_t offset = cta_idx * sizeof(index_t);
do
{
asm volatile("s_mov_b32 m0, %2\n\t"
// Depending on the architecture, the GLC flag will bypass the
// approproriate cache level(s) to avoid reading stale flags. See the
// appropriate ISA for details about the GLC modifier.
"s_load_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the load to complete
: "=s"(result)
: "s"(sk_flags_ptr), "s"(offset)
: "memory");
} while(result != 1);
}
/**
* @brief Adds the values of a block tile to an output block tile.
* @param in_out_block_tile The output block tile to which values are added.
* @param in_block_tile The input block tile whose values are added.
* @note This function iterates over the distributed spans of the block tiles and updates the
* output block tile with accumulated values.
* @note This function iterates over the distributed spans of the block tiles and updates
* the output block tile with accumulated values.
*/
template <typename OAccTile>
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
@@ -370,8 +392,8 @@ struct StreamKKernel
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile_dist The tile distribution for the block.
* @return The loaded partial block tile.
* @note This function calculates the buffer pointer and uses the tile distribution for loading
* the partial block tile.
* @note This function calculates the buffer pointer and uses the tile distribution for
* loading the partial block tile.
*/
template <typename DataType, typename OAccTileDist>
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
@@ -405,8 +427,8 @@ struct StreamKKernel
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile The block tile to be stored.
* @note This function calculates the buffer pointer and uses the tile window for storing the
* partial block tile.
* @note This function calculates the buffer pointer and uses the tile window for storing
* the partial block tile.
*/
template <typename OAccTile>
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
@@ -420,7 +442,10 @@ struct StreamKKernel
kargs.tile_partitioner.get_flags_buffer_size() +
cta_idx * c_block_tile_buffer_size;
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
const auto& partial_tensor_view = make_naive_tensor_view<
address_space_enum::global,
memory_operation_enum::set,
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE>(
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
make_tuple(TilePartitioner::NPerBlock, 1),
@@ -431,8 +456,11 @@ struct StreamKKernel
partial_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{0, 0});
store_tile(partial_tile_window, c_block_tile);
// Wait for all vector stores for this wavefront to complete
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
// Wait for all wavefronts in this workgroup to arrive here before continuing
__builtin_amdgcn_s_barrier();
}
/**
@@ -483,7 +511,8 @@ struct StreamKKernel
{
BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
}
else
else if(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Reduction ||
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::TreeReduction)
{
const auto c_macro_tile_idx =
kargs.tile_partitioner.get_output_tile_index(tile_idx);
@@ -528,46 +557,107 @@ struct StreamKKernel
auto tile_started = iter_start == tile_iter_start;
auto tile_ended = iter_end >= tile_iter_end;
if(!tile_started)
if constexpr(TilePartitioner::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
StorePartial(kargs, cta_idx, c_block_tile);
// Ensure device-wide visibility of partial results stored in global memory
// before signaling completion. __threadfence() guarantees that all global
// memory writes by this thread are visible to other threads on the device.
__threadfence(); // send signal when the store is done
SignalStorePartialDone(kargs, cta_idx);
if(!tile_started)
{
StorePartial(kargs, cta_idx, c_block_tile);
SignalStorePartialDone(kargs, cta_idx);
}
else
{
auto accum_block_tile = c_block_tile;
if(!tile_ended)
{
const index_t iter_per_tile =
kargs.tile_partitioner.get_iters_per_tile();
const index_t iter_per_cta =
kargs.tile_partitioner.get_iters_per_sk_cta();
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
int accum_iters = local_iter_end - local_iter_start;
int next_cta = cta_idx + 1;
while(accum_iters < iter_per_tile)
{
WaitStorePartialDone(kargs, next_cta);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(
accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs, next_cta, c_block_tile.get_tile_distribution()));
accum_iters += iter_per_cta + (next_cta < extra_iters);
++next_cta;
}
}
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
}
}
else
else // Tree Reduction
{
auto accum_block_tile = c_block_tile;
if(!tile_ended)
index_t tile_local_cta_idx =
kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx);
for(index_t stride = 1;; stride <<= 1)
{
const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
int accum_iters = local_iter_end - local_iter_start;
int next_cta = cta_idx + 1;
const index_t partner_cta_idx = cta_idx + stride;
const index_t partner_start_iter =
kargs.tile_partitioner.get_start_iter(partner_cta_idx);
bool partner_in_tile = partner_start_iter < tile_iter_end;
while(accum_iters < iter_per_tile)
// If the partner of the workgroup who started the tile is not in this tile,
// then the work for this tile is done and results can be stored in the C
// tensor.
if(tile_started && !partner_in_tile)
{
WaitStorePartialDone(kargs, next_cta);
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
break;
}
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(
accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs, next_cta, c_block_tile.get_tile_distribution()));
accum_iters += iter_per_cta + (next_cta < extra_iters);
++next_cta;
// It's this workgroup's turn to read from partials.
if(tile_local_cta_idx % (stride << 1) == 0)
{
// If this workgroup's partner is in the tile then it can read from
// partials and accumulate results.
if(partner_in_tile)
{
WaitStorePartialDone(kargs, partner_cta_idx);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs,
partner_cta_idx,
c_block_tile.get_tile_distribution()));
}
}
// Otherwise, it's this workgroup's turn to write to partials. All
// workgroups, except the workgroup who starts the tile, will write to
// partials.
else
{
StorePartial(kargs, cta_idx, accum_block_tile);
SignalStorePartialDone(kargs, cta_idx);
// Once the workgroup writes to partials, it has no more work to do for
// this tile.
break;
}
}
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
}
}
else
{
static_assert(
"An implementation does not exist for the chosen reduction strategy.");
}
// Prepare for next Stream-K loop iteration.
iter_start = tile_iter_end;
@@ -631,6 +721,7 @@ struct StreamKKernel
tile_idx += kargs.tile_partitioner.get_grid())
{
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
block_sync_lds();
}
// Stream-K section
@@ -639,10 +730,10 @@ struct StreamKKernel
private:
/**
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
* the starting macro tile index in the K dimension for the workgroup.
* @return A tuple containing the offsets into the A and B tensors accounting for the layouts
* of A and B.
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset
* is the starting macro tile index in the K dimension for the workgroup.
* @return A tuple containing the offsets into the A and B tensors accounting for the
* layouts of A and B.
* @note The default case is that A is assumed to be row major and B is assumed to be column
* major.
*/
@@ -679,15 +770,16 @@ struct StreamKKernel
{
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
ck_tile::hip_check_error(hipGetDevice(&dev));
ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
int num_cu = dev_prop.multiProcessorCount;
return num_cu;
}
/**
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the
* kernel
* @return The occupancy
* @note This function queries the maximum occupancy of the kernel using
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
@@ -700,7 +792,7 @@ struct StreamKKernel
constexpr int min_block_per_cu = 1;
const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
hip_check_error(
ck_tile::hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
return max(occupancy, 1);

View File

@@ -46,6 +46,16 @@ struct StreamKTilePartitionerBase
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept;
public:
/**
* @brief Calculates the start iteration for the given the cta_idx.
* @param cta_idx The current Stream-K workgroup's index.
* @return index_t The start iteration.
* @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
* non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something
* like `blockIdx.x` minus number of DP workgroups.
*/
CK_TILE_DEVICE index_t get_start_iter(index_t cta_idx) const noexcept;
/**
* @brief Calculates the start and end iteration given the cta_idx.
*
@@ -107,7 +117,17 @@ struct StreamKTilePartitionerBase
get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept;
/**
* @brief Calculates the workgroups 2D tile index in the C tensor given the 1D tile index.
* @brief Calculates the workgroup's local CTA idx within the given tile.
*
* @param tile_iter_start The starting tile iteration.
* @param cta_idx The Stream-K workgroup index.
* @return index_t The tile local workgroup index in the tile.
*/
CK_TILE_DEVICE index_t get_tile_local_cta_index(index_t tile_iter_start,
index_t cta_idx) const noexcept;
/**
* @brief Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index.
*
* @param tile_idx The 1D tile index in the C tensor for the workgroup.
* @return index_t The corresponding 2D tile index in the C tensor for the workgroup.

View File

@@ -61,13 +61,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags
return sizeof(index_t) * sk_ctas_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_start_iter(
index_t cta_idx) const noexcept
{
// Compute the number of extra iterations done before this CTA. If the cta_idx is less than
// extra_iters, the number of extra iterations before the CTA is exactly the cta_idx. Otherwise,
// it is extra_iters.
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
return total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE void
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iter_boundaries(
index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
{
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
iter = get_start_iter(cta_idx);
iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
}
@@ -104,6 +115,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_local
return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_tile_local_cta_index(
index_t tile_iter_start, index_t cta_idx) const noexcept
{
tile_iter_start = tile_iter_start - (dp_tiles_ * iters_per_tile_);
// Compute how many WGs fit before this tile starts assuming each WG does an
// extra_iter
const index_t num_extra_iter_ctas = tile_iter_start / (iters_per_sk_cta_ + 1);
// Compute how many WGs fit before this tile starts excluding extra iters
const index_t num_non_extra_iter_ctas = (tile_iter_start - extra_iters_) / iters_per_sk_cta_;
// Compute the CTA idx for the CTA that starts this tile
const index_t coop_group_start =
num_extra_iter_ctas < extra_iters_ ? num_extra_iter_ctas : num_non_extra_iter_ctas;
return cta_idx - coop_group_start;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE auto
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_output_tile_index(
@@ -121,7 +150,8 @@ CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_workspace_size(
index_t acc_element_bytes) const noexcept
{
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction ||
ReductionStrategy == StreamKReductionStrategy::TreeReduction)
{
return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();

View File

@@ -280,7 +280,7 @@ struct UniversalGemmKernel
using Kernel = UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
const auto kernel = kentry<1, Kernel, KernelArgs>;
int occupancy;
hip_check_error(
ck_tile::hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));
const int grid_size = get_available_compute_units(s) * occupancy;
@@ -1084,7 +1084,7 @@ struct UniversalGemmKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
@@ -1169,7 +1169,7 @@ struct UniversalGemmKernel
// Run the GEMM
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&

View File

@@ -9,11 +9,35 @@
namespace ck_tile {
template <typename Problem>
struct BaseGemmPipelineAGmemBGmemCRegV1
{
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr bool UsePersistentKernel = false;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
{
return TailNumber::Empty;
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
{
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
}
};
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAGmemBGmemCRegV1
struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
{
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
@@ -48,14 +72,14 @@ struct GemmPipelineAGmemBGmemCRegV1
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Problem::VectorSizeA;
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Problem::VectorSizeB;
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }

View File

@@ -9,11 +9,34 @@
namespace ck_tile {
template <typename Problem>
struct BaseGemmPipelineAGmemBGmemCRegV2
{
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
{
return TailNumber::Empty;
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
{
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
}
};
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
struct GemmPipelineAGmemBGmemCRegV2
struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2<Problem>
{
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;

View File

@@ -43,4 +43,26 @@ struct TileGemmShape
}
};
template <typename PrecType, index_t M_Warp_Tile, bool IsFlatMM = false>
constexpr index_t get_k_warp_tile()
{
#if CK_TILE_USE_WMMA
return 16;
#else
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 16 : 32;
else
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 32 : 64;
#endif
#endif
}
} // namespace ck_tile

View File

@@ -43,13 +43,14 @@ template <bool kPadM_,
bool UseStructuredSparsity_ = false,
bool UsePersistentKernel_ = false,
index_t NumWaveGroups_ = 1,
bool Preshuffle_ = false>
bool Preshuffle_ = false,
int VectorSize_ = 16>
struct TileGemmUniversalTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
static constexpr int _VectorSize = 16;
static constexpr int _VectorSize = VectorSize_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
using AsLayout = AsLayout_;

View File

@@ -61,6 +61,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
@@ -156,9 +157,11 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
using CDataType = remove_cvref_t<typename Traits::CDataType>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using OverrideBDataType = std::conditional_t<
std::is_same_v<BDataType, pk_int4_t> &&
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
ADataType,
BDataType>;
using Base = BlockGemmBQuantBase<Problem_>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;

View File

@@ -1404,7 +1404,7 @@ struct QuantGemmKernel
assert(kargs.k_batch == 1);
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemm2LDS(a_ptr,
b_ptr,

View File

@@ -325,7 +325,7 @@ struct QuantGroupedGemmKernel
kQuantType == QuantType::BQuantGrouped)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
aq_ptr,

View File

@@ -33,9 +33,17 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
// BDataType gets converted from PkInt4 during loading
using OverrideBDataType =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
std::conditional_t<std::is_same_v<BDataType, pk_int4_t> &&
std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>,
ADataType,
BDataType>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
using I0 = number<0>;
@@ -50,11 +58,6 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr index_t BQPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BQDataType>>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
static constexpr index_t BlockSize = Problem::kBlockSize;
@@ -184,6 +187,23 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
}
template <typename BBlockTile_, typename BDramWindow, typename BDramTileWindowStep>
CK_TILE_DEVICE void
BGlobalPrefetch(BBlockTile_& b_block_tile,
BDramWindow& b_copy_dram_window,
const BDramTileWindowStep& b_dram_tile_window_step) const
{
if constexpr(!std::is_same_v<BDataType, OverrideBDataType>)
{
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
}
else
{
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
}
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
@@ -262,7 +282,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<ADataType>(BBlockTileDistr{}));
decltype(make_static_distributed_tensor<OverrideBDataType>(BBlockTileDistr{}));
using BQBlockTile =
decltype(make_static_distributed_tensor<BQDataType>(BQBlockTileDistr{}));
@@ -292,8 +312,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
// B tile gets converted to A datatype during loading
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(
bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step);
@@ -314,7 +333,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
// B datatype is converted to A datatype during loading
auto b_shuffle_tmp = make_static_distributed_tensor<ADataType>(
auto b_shuffle_tmp = make_static_distributed_tensor<OverrideBDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
@@ -325,8 +344,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
// B tile gets converted to A datatype during loading
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
@@ -369,8 +388,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
LoadAndConvertBTile(b_block_tile, b_copy_dram_window);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
// B tile gets converted to A datatype during loading
BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2],
bq_copy_dram_window,
bq_dram_tile_window_step);

View File

@@ -1048,7 +1048,7 @@ struct GroupedConvolutionBackwardDataKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))

View File

@@ -1005,7 +1005,7 @@ struct GroupedConvolutionBackwardWeightKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&

View File

@@ -1184,7 +1184,7 @@ struct GroupedConvolutionForwardKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation ==
memory_operation_enum::atomic_add &&
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&

View File

@@ -425,7 +425,7 @@ struct BlockReduce2dCrossWarpSync
if constexpr(num_reduce_warps == 1)
return;
block_sync_lds();
// Each warp's lane 0 writes its partial results to shared memory
const index_t smem_offset = warp_id;
if(lane_id == 0)

View File

@@ -160,23 +160,23 @@ void dump_gemm_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_batched_gemm_json_results(const std::string& json_filename,
const std::string& op_name,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
int batch_stride_A,
int batch_stride_B,
int batch_stride_C,
int batch_count,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "batched_gemm_basic")
inline void dump_batched_gemm_json_results(const std::string& json_filename,
const std::string& op_name,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
int batch_stride_A,
int batch_stride_B,
int batch_stride_C,
int batch_count,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "batched_gemm_basic")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -218,20 +218,20 @@ void dump_grouped_gemm_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_flatmm_json_results(const std::string& json_filename,
const std::string& datatype,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
int kbatch,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "flatmm_basic")
inline void dump_flatmm_json_results(const std::string& json_filename,
const std::string& datatype,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
int kbatch,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "flatmm_basic")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -248,21 +248,22 @@ void dump_flatmm_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
const std::string& op_name,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideD0,
int StrideD1,
int StrideE,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "gemm_multi_d_fp16")
inline void
dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
const std::string& op_name,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideD0,
int StrideD1,
int StrideE,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "gemm_multi_d_fp16")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -280,14 +281,14 @@ void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_elementwise_json_results(const std::string& json_filename,
const std::string& prec,
int grid_size,
int block_size,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "elementwise")
inline void dump_elementwise_json_results(const std::string& json_filename,
const std::string& prec,
int grid_size,
int block_size,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "elementwise")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -298,22 +299,22 @@ void dump_elementwise_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_layernorm2d_fwd_json_results(const std::string& json_filename,
const std::string& prec_i,
const std::string& prec_o,
const std::string& prec_sm,
const std::string& prec_sy,
int m,
int n,
int x_stride,
int xr_stride,
int y_stride,
int yr_stride,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "layernorm2d_fwd")
inline void dump_layernorm2d_fwd_json_results(const std::string& json_filename,
const std::string& prec_i,
const std::string& prec_o,
const std::string& prec_sm,
const std::string& prec_sy,
int m,
int n,
int x_stride,
int xr_stride,
int y_stride,
int yr_stride,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "layernorm2d_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -357,13 +358,13 @@ void dump_reduce_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_permute_json_results(const std::string& json_filename,
const std::string& data_type,
bool pass,
float ave_time,
float tflop,
float gb_per_sec,
const std::string& kernel_name = "permute")
inline void dump_permute_json_results(const std::string& json_filename,
const std::string& data_type,
bool pass,
float ave_time,
float tflop,
float gb_per_sec,
const std::string& kernel_name = "permute")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -373,19 +374,19 @@ void dump_permute_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_topk_softmax_json(const std::string& json_filename,
const std::string& input_prec,
const std::string& weight_prec,
int tokens,
int experts,
int topk,
int stride_input,
int stride_output,
float ave_time,
float tflop,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "topk_softmax")
inline void dump_topk_softmax_json(const std::string& json_filename,
const std::string& input_prec,
const std::string& weight_prec,
int tokens,
int experts,
int topk,
int stride_input,
int stride_output,
float ave_time,
float tflop,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "topk_softmax")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -401,20 +402,20 @@ void dump_topk_softmax_json(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
const std::string& prec_str,
int m,
int n,
int x_stride,
int xr_stride,
int y_stride,
int yr_stride,
int use_model_sensitive_rmsnorm,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "rmsnorm2d_fwd")
inline void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
const std::string& prec_str,
int m,
int n,
int x_stride,
int xr_stride,
int y_stride,
int yr_stride,
int use_model_sensitive_rmsnorm,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "rmsnorm2d_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -431,19 +432,19 @@ void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_add_rmsnorm2d_rdquant_fwd_json(
const std::string& json_filename,
const std::string& input_data_type,
const std::string& quantized_data_type,
int m,
int n,
int stride,
float epsilon,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd")
inline void
dump_add_rmsnorm2d_rdquant_fwd_json(const std::string& json_filename,
const std::string& input_data_type,
const std::string& quantized_data_type,
int m,
int n,
int stride,
float epsilon,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -458,17 +459,17 @@ void dump_add_rmsnorm2d_rdquant_fwd_json(
END_JSON_DUMP_FILE();
}
void dump_smoothquant_json(const std::string& json_filename,
const std::string& prec_str,
int m,
int n,
int x_stride,
int y_stride,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "smoothquant")
inline void dump_smoothquant_json(const std::string& json_filename,
const std::string& prec_str,
int m,
int n,
int x_stride,
int y_stride,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "smoothquant")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -482,19 +483,19 @@ void dump_smoothquant_json(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_moe_sorting_json(const std::string& json_filename,
const std::string& index_prec,
const std::string& weight_prec,
const std::string& workspace_size,
int dispatch_policy,
int tokens,
int num_experts,
int topk,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "moe_sorting")
inline void dump_moe_sorting_json(const std::string& json_filename,
const std::string& index_prec,
const std::string& weight_prec,
const std::string& workspace_size,
int dispatch_policy,
int tokens,
int num_experts,
int topk,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "moe_sorting")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -510,19 +511,19 @@ void dump_moe_sorting_json(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_batched_transpose_json(const std::string& json_filename,
int N,
int C,
int H,
int W,
const std::string& layout_in,
const std::string& layout_out,
const std::string& prec,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "batched_transpose")
inline void dump_batched_transpose_json(const std::string& json_filename,
int N,
int C,
int H,
int W,
const std::string& layout_in,
const std::string& layout_out,
const std::string& prec,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "batched_transpose")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -538,19 +539,19 @@ void dump_batched_transpose_json(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_moe_smoothquant_json(const std::string& json_filename,
const std::string& prec_i,
const std::string& prec_o,
int tokens,
int hidden_size,
int stride,
int experts,
int topk,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "moe_smoothquant")
inline void dump_moe_smoothquant_json(const std::string& json_filename,
const std::string& prec_i,
const std::string& prec_o,
int tokens,
int hidden_size,
int stride,
int experts,
int topk,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "moe_smoothquant")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -566,26 +567,26 @@ void dump_moe_smoothquant_json(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_fused_moe_json(const std::string& json_filename,
const std::string& api_str,
const std::string& prec_str,
int tokens,
bool is_local_token,
int local_tokens,
int experts,
int topk,
int hidden_size,
int intermediate_size,
int stride,
int block_m,
int activation,
bool gate_only,
bool fused_quant,
bool pass,
float ave_time,
float tflops,
float tb_per_sec,
const std::string& kernel_name = "fused_moe")
inline void dump_fused_moe_json(const std::string& json_filename,
const std::string& api_str,
const std::string& prec_str,
int tokens,
bool is_local_token,
int local_tokens,
int experts,
int topk,
int hidden_size,
int intermediate_size,
int stride,
int block_m,
int activation,
bool gate_only,
bool fused_quant,
bool pass,
float ave_time,
float tflops,
float tb_per_sec,
const std::string& kernel_name = "fused_moe")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -610,29 +611,29 @@ void dump_fused_moe_json(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_fmha_fwd_json_results(const std::string& json_filename,
const std::string& prec,
const std::string& mode,
const std::string& io_layout,
int batch,
int nhead,
int nhead_k,
int seqlen_qs,
int seqlen_ks,
int seqlen_kpads,
int hdim_q,
int hdim_v,
float scale_s,
float p_drop,
bool lse,
const std::string& qscale,
const std::string& bias,
const std::string& vlayout,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "fmha_fwd")
inline void dump_fmha_fwd_json_results(const std::string& json_filename,
const std::string& prec,
const std::string& mode,
const std::string& io_layout,
int batch,
int nhead,
int nhead_k,
int seqlen_qs,
int seqlen_ks,
int seqlen_kpads,
int hdim_q,
int hdim_v,
float scale_s,
float p_drop,
bool lse,
const std::string& qscale,
const std::string& bias,
const std::string& vlayout,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "fmha_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
@@ -658,33 +659,33 @@ void dump_fmha_fwd_json_results(const std::string& json_filename,
END_JSON_DUMP_FILE();
}
void dump_fmha_bwd_json_results(const std::string& json_filename,
const std::string& data_type,
const std::string& mode,
const std::string& i_perm,
const std::string& o_perm,
int batch,
int nhead,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
int hdim_v,
float scale,
const std::string& bias,
bool use_dbias,
float p_drop,
bool s_randval,
bool deterministic,
const std::string& mask,
int mask_left,
int mask_right,
int workspace_size,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "fmha_bwd")
inline void dump_fmha_bwd_json_results(const std::string& json_filename,
const std::string& data_type,
const std::string& mode,
const std::string& i_perm,
const std::string& o_perm,
int batch,
int nhead,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
int hdim_v,
float scale,
const std::string& bias,
bool use_dbias,
float p_drop,
bool s_randval,
bool deterministic,
const std::string& mask,
int mask_left,
int mask_right,
int workspace_size,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "fmha_bwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);