mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
[rocm-libraries] ROCm/rocm-libraries#6574 (commit b3db057)
[CK_TILE] Add SageAttention v2 forward kernel with multi-granularity quantization (#6574) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Add a CK_TILE forward kernel implementing [SageAttention v2](https://arxiv.org/abs/2411.10958) — an attention algorithm that applies multi-granularity quantization to Q/K/V before computing attention, trading minimal accuracy loss for higher throughput on low-precision hardware. ### Quantization design | Tensor | Supported data types | Scale granularity options | |--------|---------------------|--------------------------| | Q | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (32 tokens), per-thread (4 tokens) | | K | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (64 tokens), per-thread (16 tokens) | | V | fp8 | per-channel (always) | | O | bf16 | — | Three precision combinations are supported: `fp8/bf16` (QKV fp8, O bf16), `i8/fp8/bf16` (QK int8, V fp8, O bf16), and `i4/fp8/bf16` (QK int4, V fp8, O bf16). ### Architecture support - **gfx9** (CDNA2/3, e.g. 
gfx90a, gfx942) — full tile set - **gfx950** (CDNA4) — restricted tile set (N-per-block capped at 64 for fp8-family dtypes) ### Implementation - Two pipeline variants: `QRKSVS` (synchronous) and `QRKSVS_ASYNC` (async copy) - Masking support: no mask, causal (top-left / bottom-right), and generic windowed - Batch and group (variable-length) modes - Head dimension: d=128, d_v=128 - Python codegen under `example/ck_tile/49_sageattention/codegen/` generates kernel instances per target/dtype/tile combination - Smoke tests included via `tile_example_sageattn_fwd` ### Test commands \`\`\`bash # fp8 QKV ./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=fp8bf16 -qscale=3 -init=3 # int8 QK, fp8 V ./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=i8fp8bf16 -qscale=3 -init=3 \`\`\` \`-qscale\` values: 1=per-tensor, 2=per-block, 3=per-warp, 4=per-thread
This commit is contained in:
committed by
assistant-librarian[bot]
parent
e8d64ad5c6
commit
de0a61e5c2
@@ -530,4 +530,10 @@ using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
// Warp-level GEMM alias: int8 A/B operands accumulating into int32, built from
// the MFMA 32x32x16 i8 instruction iterated twice along K (the literal `2`
// below is that K-iteration count, which yields the 32x32x32 tile in the
// alias name), with a swizzled B operand layout and a transposed C
// distribution. `swizzle_factor` selects the B-operand swizzle granularity.
template <index_t swizzle_factor = 2>
using WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution =
    WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
        WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>,
        2,
        swizzle_factor>>;
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <string>

namespace ck_tile {

// Quantization-scale granularity applied to the SageAttention Q/K tensors.
// This enum is used for codegen pattern matching.
enum class BlockSageAttentionQuantScaleEnum
{
    NO_SCALE   = 0, // no quantization scale applied
    PERTENSOR  = 1, // a single scale for the whole tensor
    BLOCKSCALE = 2, // one scale per block of tokens
    PERWARP    = 3, // one scale per warp-sized group of tokens
    PERTHREAD  = 4, // one scale per thread-sized group of tokens
};

// Maps each granularity to the short tag embedded in generated kernel names.
// The primary template is intentionally declared but not defined, so
// instantiating it with an unhandled enumerator fails at compile time.
template <BlockSageAttentionQuantScaleEnum>
struct BlockSageAttentionQuantScaleEnumToStr;

template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::NO_SCALE>
{
    static constexpr const char* name = "";
};
template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::PERTENSOR>
{
    static constexpr const char* name = "pertensor";
};
template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::BLOCKSCALE>
{
    static constexpr const char* name = "blockscale";
};
template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::PERWARP>
{
    static constexpr const char* name = "perwarp";
};
template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::PERTHREAD>
{
    static constexpr const char* name = "perthread";
};

} // namespace ck_tile
|
||||
1026
include/ck_tile/ops/sageattention/kernel/sageattn_fwd_kernel.hpp
Normal file
1026
include/ck_tile/ops/sageattention/kernel/sageattn_fwd_kernel.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

namespace ck_tile {

// Pipeline variant selector for the SageAttention forward kernel.
// This enum is used for codegen pattern matching, so enumerator values are
// spelled out explicitly (consistent with BlockSageAttentionQuantScaleEnum).
enum class BlockSageAttnPipelineEnum
{
    QRKSVS       = 0, // synchronous pipeline
    QRKSVS_ASYNC = 1, // async-copy variant
};

// Maps each pipeline variant to the short tag embedded in generated kernel
// names. The primary template is intentionally declared but not defined, so
// instantiating it with an unhandled enumerator fails at compile time.
template <BlockSageAttnPipelineEnum>
struct BlockSageAttnPipelineEnumToStr;

template <>
struct BlockSageAttnPipelineEnumToStr<BlockSageAttnPipelineEnum::QRKSVS>
{
    static constexpr const char* name = "qr";
};
template <>
struct BlockSageAttnPipelineEnumToStr<BlockSageAttnPipelineEnum::QRKSVS_ASYNC>
{
    static constexpr const char* name = "qr_async";
};

} // namespace ck_tile
|
||||
@@ -0,0 +1,60 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"

namespace ck_tile {

// Compile-time problem description for the SageAttention forward pipelines.
// Bundles the data types of every stage of the attention computation
// (Q/K/V inputs, S accumulator, softmax compute, P, O accumulator, O output),
// the block tile shape, the attention-variant and mask policies, and the
// padding / quantization traits consumed by the pipeline implementations.
template <typename QDataType_,
          typename KDataType_,
          typename VDataType_,
          typename SaccDataType_,
          typename SMPLComputeDataType_,
          typename PDataType_,
          typename OaccDataType_,
          typename ODataType_,
          typename BlockSageAttnShape_,
          bool kIsGroupMode_,
          typename AttentionVariant_,
          typename AttnMask_,
          typename Traits_>
struct BlockSageAttnPipelineProblem
{
    // Strip cv/ref qualifiers so downstream code always sees plain types.
    using QDataType           = remove_cvref_t<QDataType_>;
    using KDataType           = remove_cvref_t<KDataType_>;
    using VDataType           = remove_cvref_t<VDataType_>;
    using SaccDataType        = remove_cvref_t<SaccDataType_>; // QK-gemm accumulator type
    // Compute type used for the softmax running max/sum (M and L) reductions.
    using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
    using PDataType           = remove_cvref_t<PDataType_>;
    using OaccDataType        = remove_cvref_t<OaccDataType_>; // PV-gemm accumulator type
    using ODataType           = remove_cvref_t<ODataType_>;
    using BlockSageAttnShape  = remove_cvref_t<BlockSageAttnShape_>;
    using AttentionVariant    = remove_cvref_t<AttentionVariant_>;
    using AttnMask            = remove_cvref_t<AttnMask_>;
    using Traits              = remove_cvref_t<Traits_>;

    // Warp counts for the two gemms, and the thread-block size derived from
    // the shape's total warp count.
    static constexpr index_t kNumGemm0Warps = BlockSageAttnShape::NumGemm0Warps;
    static constexpr index_t kNumGemm1Warps = BlockSageAttnShape::NumGemm1Warps;
    static constexpr index_t kBlockSize     = BlockSageAttnShape::NumWarps * get_warp_size();

    // Group (variable-length) mode vs. batch mode.
    static constexpr bool kIsGroupMode = kIsGroupMode_;

    // attributes from traits
    static constexpr bool kPadSeqLenQ     = Traits::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK     = Traits::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ    = Traits::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV    = Traits::kPadHeadDimV;
    static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
    // Quantization-scale granularity (see BlockSageAttentionQuantScaleEnum).
    static constexpr auto QScaleEnum     = Traits::QScaleEnum;
    static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;

    /// Must match host scale tensor layout (same values as TileSageAttnTraits for Sage kernels).
    static constexpr index_t kBlockScaleSizeQ = Traits::kBlockScaleSizeQ;
    static constexpr index_t kBlockScaleSizeK = Traits::kBlockScaleSizeK;
};

} // namespace ck_tile
|
||||
@@ -0,0 +1,861 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is qkv all located in LDS
|
||||
template <typename Problem_, typename Policy_ = BlockSageAttentionPipelineQRKSVSDefaultPolicy>
|
||||
struct BlockSageAttentionPipelineQRKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using QGemmDataType = SageAttnQKGemmQDataType<Problem>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using KLdsDataType = SageAttnQKGemmKDataType<Problem>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::PDataType>;
|
||||
// fp16/bf16 example configs use P=V=fp16/bf16 (qscale=no). Quantized Sage paths use fp8 P/V;
|
||||
// FP8 softmax shift, v_descale, and PV-gemm LDS layout assume fp8_t for those cases.
|
||||
static_assert(std::is_same_v<PDataType, VDataType>,
|
||||
"SageAttention pipeline requires PDataType == VDataType for the PV gemm");
|
||||
static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
|
||||
std::is_same_v<PDataType, fp8_t>,
|
||||
"SageAttention pipeline requires PDataType = fp8_t");
|
||||
static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
|
||||
std::is_same_v<VDataType, fp8_t>,
|
||||
"SageAttention pipeline requires VDataType = fp8_t");
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
using AttnMask = remove_cvref_t<typename Problem::AttnMask>;
|
||||
|
||||
using BlockSageAttnShape = remove_cvref_t<typename Problem::BlockSageAttnShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockSageAttnShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockSageAttnShape::kM0;
|
||||
static constexpr index_t kN0 = BlockSageAttnShape::kN0;
|
||||
static constexpr index_t kK0 = BlockSageAttnShape::kK0;
|
||||
static constexpr index_t kN1 = BlockSageAttnShape::kN1;
|
||||
static constexpr index_t kK1 = BlockSageAttnShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockSageAttnShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockSageAttnShape::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
|
||||
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
|
||||
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
|
||||
|
||||
// FP8 softmax shift constants to map softmax output into representable FP8 range
|
||||
// OCP E4M3 FP8: max exponent = 8, max value ~240 (2^8 * 1.875)
|
||||
// Use shift=8.0 so exp2(s - m - 8) maps softmax to [0, 2^8] range
|
||||
// FNUZ E4M3 FP8: max exponent = 7, max value ~120 (2^7 * 1.875)
|
||||
// Use shift=7.0 so exp2(s - m - 7) maps softmax to [0, 2^7] range
|
||||
static constexpr float OCP_FP8_SHIFT = 8.0f;
|
||||
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
return Problem::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kQKHeaddim <= 32)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 64)
|
||||
{
|
||||
return 3;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 128)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 256)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr const char* name = "qr";
|
||||
|
||||
// Size of the LDS (shared-memory) workspace this pipeline requires;
// delegated to the policy so the kernel can size its allocation accordingly.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    return Policy::template GetSmemSize<Problem>();
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
AttnMask mask,
|
||||
PositionEncoding /*position_encoding*/,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
[[maybe_unused]] const float* q_descale_ptr = nullptr,
|
||||
const float* k_descale_ptr = nullptr,
|
||||
const float* v_descale_ptr = nullptr,
|
||||
[[maybe_unused]] float q_descale_value = 1.0f) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// K tile in LDS
|
||||
KLdsDataType* k_lds_ptr = static_cast<KLdsDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
auto q_dram_window_reg =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
|
||||
auto q = load_tile(q_dram_window_reg);
|
||||
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
// infer Sacc, S, P, M, L, Oacc type
|
||||
using SBlockTileType =
|
||||
std::conditional_t<std::is_same_v<typename SaccBlockTileType::DataType, SaccDataType>,
|
||||
SaccBlockTileType,
|
||||
decltype(cast_tile<SaccDataType>(SaccBlockTileType{}))>;
|
||||
|
||||
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
|
||||
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
|
||||
// init Oacc, M, L
|
||||
auto o_acc = OaccBlockTileType{};
|
||||
auto m = MLBlockTileType{};
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
auto [start, end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
return ck_tile::make_tuple(start, end);
|
||||
}();
|
||||
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto kv_load_start = seqlen_k_start > 0 ? seqlen_k_start : 0;
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(AttnMask::IsMasking || kPadSeqLenK)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
// Note: here occ are all cleard, return it
|
||||
// Note: q loaded but no fence, ignore it.
|
||||
return o_acc;
|
||||
}
|
||||
}
|
||||
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{kv_load_start, 0});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, kv_load_start},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = [&]() {
|
||||
if constexpr(std::is_same_v<QDataType, QGemmDataType>)
|
||||
return tile_elementwise_in(q_element_func, q);
|
||||
else
|
||||
{
|
||||
auto q_tile_tmp = make_static_distributed_tensor<QGemmDataType>(
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
constexpr index_t kPackedSize = numeric_traits<QDataType>::PackedSize;
|
||||
constexpr index_t kUnaryOpSize = 8;
|
||||
static_assert(std::is_same_v<QDataType, ck_tile::pk_int4_t>);
|
||||
static_assert(kPackedSize == 2);
|
||||
static_assert(decltype(q_tile_tmp)::get_thread_buffer_size() ==
|
||||
decltype(q)::get_thread_buffer_size() * kPackedSize);
|
||||
static_assert(decltype(q_tile_tmp)::get_thread_buffer_size() % kUnaryOpSize == 0);
|
||||
|
||||
using RawQType = typename QDataType::type;
|
||||
using SrcVectorType = ext_vector_t<RawQType, kUnaryOpSize / kPackedSize>;
|
||||
using DstVectorType = ext_vector_t<QGemmDataType, kUnaryOpSize>;
|
||||
constexpr index_t kVecSize =
|
||||
decltype(q_tile_tmp)::get_thread_buffer_size() / kUnaryOpSize;
|
||||
static_assert(decltype(q)::get_thread_buffer_size() ==
|
||||
kVecSize * (kUnaryOpSize / kPackedSize));
|
||||
|
||||
const element_wise::PassThroughPack8 pass_through_pack8{};
|
||||
static_for<0, kVecSize, 1>{}([&](auto i) {
|
||||
pass_through_pack8(
|
||||
q_tile_tmp.get_thread_buffer().template get_as<DstVectorType>()(i),
|
||||
q.get_thread_buffer().template get_as<SrcVectorType>()[i]);
|
||||
});
|
||||
return q_tile_tmp;
|
||||
}
|
||||
}();
|
||||
|
||||
// prefetch K tile
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
// Use compile-time conditional for group barrier sequence
|
||||
// (No runtime lambda selection)
|
||||
auto schedule_gemm0 = [] {
|
||||
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
|
||||
constexpr auto WarpGemmConfig =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
|
||||
constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>();
|
||||
constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>();
|
||||
constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM;
|
||||
constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN;
|
||||
constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK;
|
||||
constexpr index_t NumMfmaInsts = (kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) *
|
||||
(kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp);
|
||||
if constexpr(get_warp_size() == 64 && kQKHeaddim == 256)
|
||||
{
|
||||
static_assert(NumMfmaInsts % 8 == 0);
|
||||
static_for<0, NumMfmaInsts / 8, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(DS_READ, 2, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(MFMA, 4, 0); // MFMA
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
static_assert(2 <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
constexpr index_t kGemm0MPerWarp = BlockSageAttnShape::Gemm0WarpTile::at(number<0>{});
|
||||
static_assert(get_warp_size() % kGemm0MPerWarp == 0);
|
||||
constexpr index_t kWarpSz = get_warp_size();
|
||||
// sub_warp_idx is 0 or 1, indicating which half of the warp (used for PERTHREAD K-scale
|
||||
// indexing)
|
||||
index_t sub_warp_idx = (threadIdx.x % kWarpSz) / kGemm0MPerWarp;
|
||||
// main loop
|
||||
do
|
||||
{
|
||||
float k_descale = 1.0f;
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
const index_t kv_idx =
|
||||
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
|
||||
k_descale = k_descale_ptr[kv_idx];
|
||||
}
|
||||
constexpr index_t kNumKScalesPW =
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP
|
||||
? kN0 / Problem::kBlockScaleSizeK
|
||||
: 1;
|
||||
constexpr index_t kNumKScalesPT =
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD
|
||||
? kN0 / Problem::kBlockScaleSizeK / 2
|
||||
: 1;
|
||||
float k_scales_perwarp[kNumKScalesPW > 0 ? kNumKScalesPW : 1] = {};
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
|
||||
{
|
||||
const index_t kv_idx =
|
||||
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPW; i++)
|
||||
k_scales_perwarp[i] = k_descale_ptr[kv_idx + i];
|
||||
}
|
||||
float k_scales_reg[kNumKScalesPT > 0 ? kNumKScalesPT : 1] = {};
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
const index_t k_global_start = seqlen_k_start + i_total_loops * kN0;
|
||||
const index_t k_scale_start_idx = k_global_start / Problem::kBlockScaleSizeK;
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPT; i++)
|
||||
k_scales_reg[i] = k_descale_ptr[k_scale_start_idx + 2 * i + sub_warp_idx];
|
||||
}
|
||||
|
||||
// STAGE 1, QK gemm
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
auto s_acc_gemm = SaccBlockTileType{};
|
||||
const auto store_k_block_tile_to_lds = [&](const auto& k_block_tile_) {
|
||||
if constexpr(std::is_same_v<KDataType, KLdsDataType>)
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile_));
|
||||
else
|
||||
{
|
||||
auto k_block_tile_tmp = make_static_distributed_tensor<KLdsDataType>(
|
||||
k_dram_window.get_tile_distribution());
|
||||
using KBlockTileType = remove_cvref_t<decltype(k_block_tile_)>;
|
||||
constexpr index_t kPackedSize = numeric_traits<KDataType>::PackedSize;
|
||||
constexpr index_t kUnaryOpSize = 8;
|
||||
static_assert(std::is_same_v<KDataType, ck_tile::pk_int4_t>);
|
||||
static_assert(kPackedSize == 2);
|
||||
static_assert(decltype(k_block_tile_tmp)::get_thread_buffer_size() ==
|
||||
KBlockTileType::get_thread_buffer_size() * kPackedSize);
|
||||
static_assert(
|
||||
decltype(k_block_tile_tmp)::get_thread_buffer_size() % kUnaryOpSize == 0);
|
||||
|
||||
using RawKType = typename KDataType::type;
|
||||
using SrcVectorType = ext_vector_t<RawKType, kUnaryOpSize / kPackedSize>;
|
||||
using DstVectorType = ext_vector_t<KLdsDataType, kUnaryOpSize>;
|
||||
constexpr index_t kVecSize =
|
||||
decltype(k_block_tile_tmp)::get_thread_buffer_size() / kUnaryOpSize;
|
||||
static_assert(KBlockTileType::get_thread_buffer_size() ==
|
||||
kVecSize * (kUnaryOpSize / kPackedSize));
|
||||
|
||||
const element_wise::PassThroughPack8 pass_through_pack8{};
|
||||
static_for<0, kVecSize, 1>{}([&](auto i) {
|
||||
pass_through_pack8(
|
||||
k_block_tile_tmp.get_thread_buffer().template get_as<DstVectorType>()(
|
||||
i),
|
||||
k_block_tile_.get_thread_buffer().template get_as<SrcVectorType>()[i]);
|
||||
});
|
||||
store_tile(k_lds_window, k_block_tile_tmp);
|
||||
}
|
||||
};
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
{
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
clear_tile(s_acc_gemm); // initialize C
|
||||
store_k_block_tile_to_lds(k_block_tile);
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
}
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
{
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc_gemm,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
k_lds_window);
|
||||
schedule_gemm0();
|
||||
block_sync_lds();
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
store_k_block_tile_to_lds(k_block_tile); // LDS write i + 1
|
||||
k_block_tile = load_tile(k_dram_window); // global read i + 2
|
||||
});
|
||||
}
|
||||
|
||||
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc_gemm,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 2) * kK0>{},
|
||||
sequence<kM0, (k0_loops - 1) * kK0>{}),
|
||||
k_lds_window);
|
||||
schedule_gemm0();
|
||||
block_sync_lds();
|
||||
|
||||
store_k_block_tile_to_lds(k_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(s_acc_gemm,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_lds_window);
|
||||
schedule_gemm0();
|
||||
}
|
||||
|
||||
// Convert GEMM output to SaccDataType for softmax (if needed)
|
||||
auto s_acc = [&]() {
|
||||
using GemmDataType = typename decltype(s_acc_gemm)::DataType;
|
||||
if constexpr(std::is_same_v<GemmDataType, SaccDataType>)
|
||||
{
|
||||
return s_acc_gemm; // No conversion needed (e.g., float -> float)
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<SaccDataType>(s_acc_gemm); // Convert (e.g., int32 -> float)
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
// PERTHREAD: kBlockScaleSizeK=16
|
||||
// The s_acc tile distribution is determined by
|
||||
// WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution, which guarantees
|
||||
// each thread processes exactly 16 consecutive elements in the K dimension. This
|
||||
// distribution is inherent to the MFMA 32x32x16 instruction with kKIter=2 and
|
||||
// TransposedC layout. Therefore, col_offset >> 4 correctly maps thread-local
|
||||
// elements to K scale indices.
|
||||
static_assert(Problem::kBlockScaleSizeK == 16,
|
||||
"PERTHREAD: kBlockScaleSizeK must be 16");
|
||||
|
||||
// Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB +
|
||||
// TransposedC This ensures the distribution has 16 consecutive K elements per
|
||||
// thread
|
||||
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
|
||||
constexpr auto WarpGemmCfg =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
|
||||
using ExpectedWarpGemmI8 =
|
||||
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
using ExpectedWarpGemmFp8 =
|
||||
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
static_assert(
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
|
||||
"PERTHREAD requires "
|
||||
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
|
||||
"16 consecutive K elements");
|
||||
|
||||
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
|
||||
float combined_scales_reg[kNumKScalesPT] = {};
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPT; i++)
|
||||
combined_scales_reg[i] = q_descale_value * k_scales_reg[i];
|
||||
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
|
||||
index_t col_offset = 0;
|
||||
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// col_offset counts columns in distributed view
|
||||
// Divide by 16 (>>4) to map to K scale groups (kBlockScaleSizeK=16)
|
||||
const index_t scale_idx = col_offset >> 4;
|
||||
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
|
||||
col_offset++;
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
|
||||
{
|
||||
// PERWARP: kBlockScaleSizeK=64, i.e., 64 global K elements share one scale
|
||||
// Distribution: thread_i and thread_(i+32) interleave to cover K dimension
|
||||
// In each thread's view, every 32 idx1 steps correspond to 64 global K elements
|
||||
|
||||
// Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB +
|
||||
// TransposedC This ensures each thread has 16 consecutive elements, and warp-level
|
||||
// grouping is correct
|
||||
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
|
||||
constexpr auto WarpGemmCfg =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
|
||||
using ExpectedWarpGemmI8 =
|
||||
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
using ExpectedWarpGemmFp8 =
|
||||
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
static_assert(
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
|
||||
"PERWARP requires "
|
||||
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
|
||||
"correct K element grouping");
|
||||
|
||||
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
|
||||
float combined_scales_reg[kNumKScalesPW] = {};
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPW; i++)
|
||||
combined_scales_reg[i] = q_descale_value * k_scales_perwarp[i];
|
||||
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
|
||||
index_t col_offset = 0;
|
||||
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// col_offset counts columns in distributed view
|
||||
// When N0=64: each thread has 32 elements; when N0=128: each thread has 64
|
||||
// elements Divide by 32 (>>5) to map to K scale groups
|
||||
// (kBlockScaleSizeK=64)
|
||||
const index_t scale_idx = col_offset >> 5;
|
||||
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
|
||||
col_offset++;
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// dequant: combine q_descale (in s_acc_element_func) with k_descale
|
||||
auto s_acc_element_func_ = [&]() {
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
return s_acc_element_func * k_descale;
|
||||
}
|
||||
else
|
||||
return s_acc_element_func;
|
||||
}();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
}
|
||||
// STAGE 2, scale_s, mask, softmax
|
||||
if constexpr(kPadSeqLenK || AttnMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !mask_func(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
};
|
||||
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
|
||||
|
||||
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
|
||||
s.get_tile_distribution()); // Pcompute{j}
|
||||
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
if constexpr(AttnMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_m;
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
// For BLOCKSCALE: precompute (m - shift) once per row
|
||||
// exp2(s - m + shift) = exp2(s - (m - shift)); pertensor path uses scale_s on s,m
|
||||
auto validated_m = get_validated_m(m[i_idx]);
|
||||
auto row_max = scale_s * validated_m;
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP ||
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
validated_m -= OCP_FP8_SHIFT; // OCP FP8 softmax shift
|
||||
row_max -= OCP_FP8_SHIFT; // for else branch
|
||||
#else
|
||||
validated_m -= FNUZ_FP8_SHIFT;
|
||||
row_max -= FNUZ_FP8_SHIFT;
|
||||
#endif
|
||||
}
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
});
|
||||
});
|
||||
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
const auto m_new = get_validated_m(m[i_idx]);
|
||||
auto row_max = scale_s * m_new;
|
||||
const auto tmp = exp2(scale_s * m_old[i_idx] - row_max);
|
||||
// Update l and rescale o_acc
|
||||
l(i_idx) = tmp * l(i_idx) + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_prefetch);
|
||||
store_tile(
|
||||
v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
|
||||
}
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
// For BLOCKSCALE, PERWARP, and PERTHREAD modes, accumulate directly to o_acc
|
||||
// Apply per-channel v_descale after the loop (before normalization)
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v);
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v)); // store next v
|
||||
}
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
}
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// Apply per-channel v_descale for BLOCKSCALE, PERWARP, and PERTHREAD modes (after loop,
|
||||
// before normalization)
|
||||
if constexpr(Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
|
||||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP ||
|
||||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
// Ensure all V LDS reads from the last gemm_1 complete before reusing K/V LDS space
|
||||
block_sync_lds();
|
||||
|
||||
// V is col-major, each column (channel) has its own scale
|
||||
// o_acc shape: [M0, N1] where N1 is hdim_v
|
||||
// v_descale_ptr points to per-channel scales [hdim_v]
|
||||
// Load v_descale to LDS for better memory access pattern
|
||||
// Reuse K/V LDS space (they're no longer needed)
|
||||
auto v_descale_lds = reinterpret_cast<float*>(smem_ptr);
|
||||
|
||||
// Cooperatively load v_descale to LDS
|
||||
const index_t num_threads = kBlockSize;
|
||||
for(index_t i = threadIdx.x; i < kN1; i += num_threads)
|
||||
{
|
||||
v_descale_lds[i] = v_descale_ptr[i];
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
constexpr auto o_tmp_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_tmp_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(o_tmp_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// Get the global tile index for the N1 (channel) dimension
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
o_acc.get_tile_distribution(), i_j_idx);
|
||||
const index_t channel_idx = tile_idx.at(number<1>{});
|
||||
const float v_scale = v_descale_lds[channel_idx];
|
||||
o_acc(i_j_idx) *= v_scale;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// finally, O
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = [&]() {
|
||||
// When masking, the denominator can be zero; guard the normalization
|
||||
// so we do not divide by zero after a fully masked row.
|
||||
if constexpr(AttnMask::IsMasking)
|
||||
{
|
||||
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
|
||||
}
|
||||
else
|
||||
return 1 / l[i_idx];
|
||||
}();
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
    /// Convenience overload of the pipeline entry point without element functions.
    ///
    /// Forwards to the full overload, passing ck_tile::identity{} for every
    /// element-wise functor slot (Q, K, V inputs, S accumulator, P compute and
    /// O accumulator), i.e. no extra per-element transformation is applied.
    /// All other arguments are forwarded unchanged, so the argument order here
    /// must stay in sync with the full overload's parameter list.
    ///
    /// @param q_dram_block_window_tmp  Q DRAM window (M0*K0 tile)
    /// @param k_dram_block_window_tmp  K DRAM window (N0*K0 tile)
    /// @param v_dram_block_window_tmp  V DRAM window (N1*K1 tile)
    /// @param mask              attention mask descriptor (causal / windowed / none)
    /// @param position_encoding positional-encoding functor
    /// @param scale_s           softmax scale applied to S = Q*K^T
    /// @param variant           attention variant (supplies e.g. LogitsMask)
    /// @param variant_params    parameters consumed by the variant
    /// @param block_indices     batch / head indices for this workgroup
    /// @param smem_ptr          LDS workspace (size per GetSmemSize())
    /// @param q_descale_ptr     optional per-granularity Q descale array (quantized paths)
    /// @param k_descale_ptr     optional per-granularity K descale array (quantized paths)
    /// @param v_descale_ptr     optional per-channel V descale array (quantized paths)
    /// @param q_descale_value   scalar Q descale (PERWARP/PERTHREAD paths)
    /// @return the normalized output accumulator tile (O)
    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename PositionEncoding,
              typename AttentionVariantParams,
              typename BlockIndices>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
               AttnMask mask,
               PositionEncoding position_encoding,
               float scale_s,
               const AttentionVariant& variant,
               const AttentionVariantParams& variant_params,
               const BlockIndices& block_indices,
               void* smem_ptr,
               [[maybe_unused]] const float* q_descale_ptr = nullptr,
               const float* k_descale_ptr = nullptr,
               const float* v_descale_ptr = nullptr,
               [[maybe_unused]] float q_descale_value = 1.0f) const
    {
        // identity{} in each element-function slot == "no transformation".
        return operator()(q_dram_block_window_tmp,
                          identity{},
                          k_dram_block_window_tmp,
                          identity{},
                          v_dram_block_window_tmp,
                          identity{},
                          identity{},
                          identity{},
                          identity{},
                          mask,
                          position_encoding,
                          scale_s,
                          variant,
                          variant_params,
                          block_indices,
                          smem_ptr,
                          q_descale_ptr,
                          k_descale_ptr,
                          v_descale_ptr,
                          q_descale_value);
    }
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
|
||||
template <typename Problem_, typename Policy_ = BlockSageAttentionPipelineQRKSVSAsyncDefaultPolicy>
|
||||
struct BlockSageAttentionPipelineQRKSVSAsync
|
||||
{
|
||||
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;
    // Per-tensor element types, all taken from the Problem descriptor.
    using QDataType           = remove_cvref_t<typename Problem::QDataType>;
    using KDataType           = remove_cvref_t<typename Problem::KDataType>;
    using VDataType           = remove_cvref_t<typename Problem::VDataType>;
    using SaccDataType        = remove_cvref_t<typename Problem::SaccDataType>;
    using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
    using PDataType           = remove_cvref_t<typename Problem::PDataType>;
    // fp16/bf16 example configs use P=V=fp16/bf16 (qscale=no). Quantized Sage paths use fp8 P/V;
    // FP8 softmax shift, v_descale, and PV-gemm LDS layout assume fp8_t for those cases.
    static_assert(std::is_same_v<PDataType, VDataType>,
                  "SageAttention pipeline requires PDataType == VDataType for the PV gemm");
    // Unless Q is half/bf16 (non-quantized path), P must be fp8.
    static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
                      std::is_same_v<PDataType, fp8_t>,
                  "SageAttention pipeline requires PDataType = fp8_t");
    // Unless Q is half/bf16 (non-quantized path), V must be fp8.
    static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
                      std::is_same_v<VDataType, fp8_t>,
                  "SageAttention pipeline requires VDataType = fp8_t");
    using OaccDataType     = remove_cvref_t<typename Problem::OaccDataType>;
    using ODataType        = remove_cvref_t<typename Problem::ODataType>;
    using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
    using AttnMask         = remove_cvref_t<typename Problem::AttnMask>;

    // Block tile shape (M0/N0/K0/N1/K1, head dims) and V memory layout.
    using BlockSageAttnShape = remove_cvref_t<typename Problem::BlockSageAttnShape>;
    using VLayout            = remove_cvref_t<typename BlockSageAttnShape::VLayout>;
    static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
    // Pipeline and policy must agree on the Q-load strategy.
    static_assert(kQLoadOnce == Policy::QLoadOnce);

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    // Tile extents: M0 = Q rows, N0 = K rows (seq_k tile), K0 = QK gemm K-step,
    // N1 = V head dim (hdim_v), K1 = PV gemm K-step.
    static constexpr index_t kM0           = BlockSageAttnShape::kM0;
    static constexpr index_t kN0           = BlockSageAttnShape::kN0;
    static constexpr index_t kK0           = BlockSageAttnShape::kK0;
    static constexpr index_t kN1           = BlockSageAttnShape::kN1;
    static constexpr index_t kK1           = BlockSageAttnShape::kK1;
    static constexpr index_t kQKHeaddim    = BlockSageAttnShape::kQKHeaddim;
    static constexpr index_t kSubQKHeaddim = BlockSageAttnShape::kSubQKHeaddim;

    static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");

    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
    // only need special care about seq_k padding (oob need set -INF of p instead of zero)
    static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
                  Problem::kPadHeadDimV == true);
    static constexpr bool kPadSeqLenQ  = true;
    static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
    static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
    // Quantization-scale granularity (NO / PERTENSOR / BLOCKSCALE / PERWARP / PERTHREAD).
    static constexpr auto QScaleEnum = Problem::QScaleEnum;

    // last dimension vector length used to create tensor view(and decide buffer_load vector length)
    // ... together with tensor distribution. tensor dist should able to overwrite this
    static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
    static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
    static constexpr index_t kAlignmentV = []() {
        if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            return Policy::template GetAlignmentV<Problem>();
        else
            // Col-major V with padded seq_k: fall back to scalar loads so OOB
            // elements can be handled per element.
            return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
    }();
    static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();

    // FP8 softmax shift constants to map softmax output into representable FP8 range
    // OCP E4M3 FP8: max exponent = 8, max value ~240 (2^8 * 1.875)
    // Use shift=8.0 so exp2(s - m - 8) maps softmax to [0, 2^8] range
    // FNUZ E4M3 FP8: max exponent = 7, max value ~120 (2^7 * 1.875)
    // Use shift=7.0 so exp2(s - m - 7) maps softmax to [0, 2^7] range
    static constexpr float OCP_FP8_SHIFT  = 8.0f;
    static constexpr float FNUZ_FP8_SHIFT = 7.0f;
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
return Problem::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kQKHeaddim <= 32)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 64)
|
||||
{
|
||||
return 3;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 128)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 192)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 256)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
};
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr const char* name = "qr_async";
|
||||
|
||||
    /// Returns the LDS (shared memory) workspace size in bytes required by this
    /// pipeline, as computed by the policy for the given Problem configuration.
    /// The caller must pass a buffer of at least this size as `smem_ptr`.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const KElementFunction& /*k_element_func*/,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
AttnMask mask,
|
||||
PositionEncoding /*position_encoding*/,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
[[maybe_unused]] const float* q_descale_ptr = nullptr,
|
||||
const float* k_descale_ptr = nullptr,
|
||||
const float* v_descale_ptr = nullptr,
|
||||
[[maybe_unused]] float q_descale_value = 1.0f) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
|
||||
|
||||
// K tile in LDS
|
||||
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
|
||||
auto k_lds_store = generate_tuple(
|
||||
[&](auto i_buf) {
|
||||
return make_tile_window(
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
|
||||
Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
|
||||
{0, 0, 0});
|
||||
},
|
||||
number<Policy::NumKVLdsBuffers>{});
|
||||
|
||||
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
|
||||
|
||||
auto k_lds_load =
|
||||
make_tile_window(k_lds_Load_view,
|
||||
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
|
||||
{0, 0});
|
||||
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
q_dram_window.init_raw();
|
||||
|
||||
// TODO: we use async Copy for K, which is inline asm
|
||||
// a side effect is we have to use inline asm for q as well
|
||||
auto q = decltype(load_tile(q_dram_window)){};
|
||||
// TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
|
||||
// however, q would be cleared in the constructor of static distributed tensor
|
||||
// set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
|
||||
load_tile_raw(q, q_dram_window);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
|
||||
// infer Sacc, S, P, M, L, Oacc type
|
||||
using SBlockTileType =
|
||||
std::conditional_t<std::is_same_v<typename SaccBlockTileType::DataType, SaccDataType>,
|
||||
SaccBlockTileType,
|
||||
decltype(cast_tile<SaccDataType>(SaccBlockTileType{}))>;
|
||||
|
||||
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
|
||||
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
|
||||
// init Oacc, M, L
|
||||
auto o_acc = OaccBlockTileType{};
|
||||
auto m = MLBlockTileType{};
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
auto [start, end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
return ck_tile::make_tuple(start, end);
|
||||
}();
|
||||
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto kv_load_start = seqlen_k_start > 0 ? seqlen_k_start : 0;
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(AttnMask::IsMasking || kPadSeqLenK)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
|
||||
// otherwise will have compute error(maybe compiler bug?)
|
||||
|
||||
// Note: here occ are all cleard, return it
|
||||
return o_acc;
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
|
||||
}
|
||||
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{kv_load_start, 0});
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
k_dram_window.init_raw();
|
||||
constexpr auto k_oob_ck = bool_constant<true>{};
|
||||
constexpr auto k_pre_np = bool_constant<false>{};
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, kv_load_start},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
// prefetch K tile
|
||||
async_load_tile_raw(
|
||||
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
|
||||
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
|
||||
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
static_assert(1 <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
constexpr index_t kGemm0MPerWarp = BlockSageAttnShape::Gemm0WarpTile::at(number<0>{});
|
||||
static_assert(kGemm0MPerWarp == 32);
|
||||
constexpr index_t kWarpSz = get_warp_size();
|
||||
// sub_warp_idx is 0 or 1, indicating which half of the warp (used for PERTHREAD K-scale
|
||||
// indexing)
|
||||
index_t sub_warp_idx = (threadIdx.x % kWarpSz) / kGemm0MPerWarp;
|
||||
// main loop
|
||||
do
|
||||
{
|
||||
float k_descale = 1.0f;
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
const index_t kv_idx =
|
||||
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
|
||||
k_descale = k_descale_ptr[kv_idx];
|
||||
}
|
||||
constexpr index_t kNumKScalesPW =
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP
|
||||
? kN0 / Problem::kBlockScaleSizeK
|
||||
: 1;
|
||||
constexpr index_t kNumKScalesPT =
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD
|
||||
? kN0 / Problem::kBlockScaleSizeK / 2
|
||||
: 1;
|
||||
float k_scales_perwarp[kNumKScalesPW > 0 ? kNumKScalesPW : 1] = {};
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
|
||||
{
|
||||
const index_t kv_idx =
|
||||
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPW; i++)
|
||||
k_scales_perwarp[i] = k_descale_ptr[kv_idx + i];
|
||||
}
|
||||
float k_scales_reg[kNumKScalesPT > 0 ? kNumKScalesPT : 1] = {};
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
const index_t k_global_start = seqlen_k_start + i_total_loops * kN0;
|
||||
const index_t k_scale_start_idx = k_global_start / Problem::kBlockScaleSizeK;
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPT; i++)
|
||||
k_scales_reg[i] = k_descale_ptr[k_scale_start_idx + 2 * i + sub_warp_idx];
|
||||
}
|
||||
|
||||
// STAGE 1, QK gemm
|
||||
auto s_acc_gemm = SaccBlockTileType{};
|
||||
clear_tile(s_acc_gemm); // initialize C
|
||||
if constexpr(k0_loops > 1)
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
|
||||
k_dram_window,
|
||||
number<-1>{},
|
||||
k_oob_ck,
|
||||
k_pre_np);
|
||||
if constexpr(i_k0 < k0_loops - 1)
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
async_load_fence(k_dram_window.get_num_of_access());
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
gemm_0(s_acc_gemm,
|
||||
get_slice_tile(
|
||||
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
get_slice_tile(k_lds_load,
|
||||
sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
|
||||
sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
|
||||
});
|
||||
}
|
||||
|
||||
// TODO: this to fix a bug when loop smaller than 2,
|
||||
// the following fence/barrier will be scheduled inside 1st loop
|
||||
if constexpr(k0_loops <= 2)
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
async_load_fence();
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
{ // tail
|
||||
gemm_0(
|
||||
s_acc_gemm,
|
||||
get_slice_tile(
|
||||
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
|
||||
get_slice_tile(k_lds_load,
|
||||
sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(1);
|
||||
|
||||
// Convert GEMM output to SaccDataType for softmax (if needed)
|
||||
auto s_acc = [&]() {
|
||||
using GemmDataType = typename decltype(s_acc_gemm)::DataType;
|
||||
if constexpr(std::is_same_v<GemmDataType, SaccDataType>)
|
||||
{
|
||||
return s_acc_gemm; // No conversion needed (e.g., float -> float)
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<SaccDataType>(s_acc_gemm); // Convert (e.g., int32 -> float)
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
// PERTHREAD: kBlockScaleSizeK=16
|
||||
// The s_acc tile distribution is determined by
|
||||
// WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution, which guarantees
|
||||
// each thread processes exactly 16 consecutive elements in the K dimension. This
|
||||
// distribution is inherent to the MFMA 32x32x16 instruction with kKIter=2 and
|
||||
// TransposedC layout. Therefore, col_offset >> 4 correctly maps thread-local
|
||||
// elements to K scale indices.
|
||||
static_assert(Problem::kBlockScaleSizeK == 16,
|
||||
"PERTHREAD: kBlockScaleSizeK must be 16");
|
||||
|
||||
// Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB +
|
||||
// TransposedC This ensures the distribution has 16 consecutive K elements per
|
||||
// thread
|
||||
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
|
||||
constexpr auto WarpGemmCfg =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
|
||||
using ExpectedWarpGemmI8 =
|
||||
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
using ExpectedWarpGemmFp8 =
|
||||
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
static_assert(
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
|
||||
"PERTHREAD requires "
|
||||
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
|
||||
"16 consecutive K elements");
|
||||
|
||||
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
|
||||
float combined_scales_reg[kNumKScalesPT] = {};
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPT; i++)
|
||||
combined_scales_reg[i] = q_descale_value * k_scales_reg[i];
|
||||
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
|
||||
index_t col_offset = 0;
|
||||
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// col_offset counts columns in distributed view
|
||||
// Divide by 16 (>>4) to map to K scale groups (kBlockScaleSizeK=16)
|
||||
const index_t scale_idx = col_offset >> 4;
|
||||
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
|
||||
col_offset++;
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
|
||||
{
|
||||
// PERWARP: kBlockScaleSizeK=64, i.e., 64 global K elements share one scale
|
||||
// Distribution: thread_i and thread_(i+32) interleave to cover K dimension
|
||||
// In each thread's view, every 32 idx1 steps correspond to 64 global K elements
|
||||
|
||||
// Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB +
|
||||
// TransposedC This ensures each thread has 16 consecutive elements, and warp-level
|
||||
// grouping is correct
|
||||
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
|
||||
constexpr auto WarpGemmCfg =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
|
||||
using ExpectedWarpGemmI8 =
|
||||
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
using ExpectedWarpGemmFp8 =
|
||||
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
|
||||
static_assert(
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
|
||||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
|
||||
"PERWARP requires "
|
||||
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
|
||||
"correct K element grouping");
|
||||
|
||||
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
|
||||
float combined_scales_reg[kNumKScalesPW] = {};
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kNumKScalesPW; i++)
|
||||
combined_scales_reg[i] = q_descale_value * k_scales_perwarp[i];
|
||||
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
|
||||
index_t col_offset = 0;
|
||||
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// col_offset counts columns in distributed view
|
||||
// When N0=64: each thread has 32 elements; when N0=128: each thread has 64
|
||||
// elements Divide by 32 (>>5) to map to K scale groups
|
||||
// (kBlockScaleSizeK=64)
|
||||
const index_t scale_idx = col_offset >> 5;
|
||||
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
|
||||
col_offset++;
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// dequant: combine q_descale (in s_acc_element_func) with k_descale
|
||||
auto s_acc_element_func_ = [&]() {
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
return s_acc_element_func * k_descale;
|
||||
}
|
||||
else
|
||||
return s_acc_element_func;
|
||||
}();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
|
||||
}
|
||||
// STAGE 2, scale_s, mask, softmax
|
||||
// logits_soft_cap is always disabled
|
||||
if constexpr(kPadSeqLenK || AttnMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !mask_func(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
};
|
||||
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
tile_elementwise_inout(
|
||||
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
|
||||
|
||||
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
|
||||
s.get_tile_distribution()); // Pcompute{j}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7F);
|
||||
// Ensure gemm_0's LDS reads (K tile) from all threads are completed before V store
|
||||
// Only needed when K tail and V use the same LDS buffer
|
||||
if constexpr(LdsSeq.at(number<k0_loops - 1>{}) == LdsSeq.at(number<k0_loops>{}))
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
// store & prefetch next v, after the max reduction
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_buf);
|
||||
|
||||
auto v_lds_window_tmp =
|
||||
get_slice_tile(v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
|
||||
|
||||
store_tile(
|
||||
v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
auto v_lds_window_tmp =
|
||||
get_slice_tile(v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
|
||||
store_tile(v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
|
||||
}
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
move_tile_window(
|
||||
v_dram_window,
|
||||
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
if constexpr(AttnMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
else
|
||||
{
|
||||
return raw_m;
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
// For BLOCKSCALE: precompute (m - shift) once per row
|
||||
// exp2(s - m + shift) = exp2(s - (m - shift))
|
||||
// else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift))
|
||||
auto validated_m = get_validated_m(m[i_idx]);
|
||||
auto row_max = scale_s * validated_m;
|
||||
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP ||
|
||||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
validated_m -= OCP_FP8_SHIFT; // OCP FP8 softmax shift
|
||||
row_max -= OCP_FP8_SHIFT; // for else branch
|
||||
#else
|
||||
validated_m -= FNUZ_FP8_SHIFT;
|
||||
row_max -= FNUZ_FP8_SHIFT;
|
||||
#endif
|
||||
}
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// logits_soft_cap is always disabled
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
});
|
||||
});
|
||||
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
const auto m_new = get_validated_m(m[i_idx]);
|
||||
auto row_max = scale_s * m_new;
|
||||
const auto tmp = exp2(scale_s * m_old[i_idx] - row_max);
|
||||
// Update l and rescale o_acc
|
||||
l(i_idx) = tmp * l(i_idx) + rowsum_p[i_idx];
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
const auto p = [&]() {
|
||||
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
|
||||
// For fp32 to fp16,
|
||||
// impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue,
|
||||
// since it uses __builtin_amdgcn_cvt_pkrtz, which is round to zero.
|
||||
return cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#else
|
||||
if constexpr(std::is_same_v<PDataType, fp16_t>)
|
||||
return impl::cast_tile_pkrtz_fp16_fp32<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
else
|
||||
return cast_tile<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#endif
|
||||
}();
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
// For BLOCKSCALE, PERWARP, and PERTHREAD modes, accumulate directly to o_acc
|
||||
// Apply per-channel v_descale after the loop (before normalization)
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
|
||||
{
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
}
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_buf);
|
||||
auto v_lds_window_tmp = get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
|
||||
store_tile(v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
auto v_lds_window_tmp = get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
|
||||
store_tile(v_lds_window_tmp,
|
||||
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
|
||||
}
|
||||
if constexpr(i_k1 < k1_loops - 1)
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
}
|
||||
i_total_loops++;
|
||||
if(i_total_loops < num_total_loop)
|
||||
{
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
|
||||
if constexpr(k1_loops >= 2 &&
|
||||
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
|
||||
__builtin_amdgcn_s_barrier();
|
||||
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
|
||||
k_dram_window,
|
||||
number<-1>{},
|
||||
k_oob_ck,
|
||||
k_pre_np);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(
|
||||
o_acc,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
get_slice_tile(
|
||||
v_lds_window,
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
|
||||
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
|
||||
}
|
||||
|
||||
} while(i_total_loops < num_total_loop);
|
||||
|
||||
// Apply per-channel v_descale for BLOCKSCALE, PERWARP, and PERTHREAD modes (after loop,
|
||||
// before normalization)
|
||||
if constexpr(Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
|
||||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP ||
|
||||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD)
|
||||
{
|
||||
// Ensure all V LDS reads from the last gemm_1 complete before reusing K/V LDS space
|
||||
block_sync_lds();
|
||||
|
||||
// V is col-major, each column (channel) has its own scale
|
||||
// o_acc shape: [M0, N1] where N1 is hdim_v
|
||||
// v_descale_ptr points to per-channel scales [hdim_v]
|
||||
// Load v_descale to LDS for better memory access pattern
|
||||
// Reuse K/V LDS space (they're no longer needed)
|
||||
auto v_descale_lds = reinterpret_cast<float*>(smem_ptr);
|
||||
|
||||
// Cooperatively load v_descale to LDS
|
||||
const index_t num_threads = kBlockSize;
|
||||
for(index_t i = threadIdx.x; i < kN1; i += num_threads)
|
||||
{
|
||||
v_descale_lds[i] = v_descale_ptr[i];
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
constexpr auto o_tmp_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_tmp_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(o_tmp_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// Get the global tile index for the N1 (channel) dimension
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
o_acc.get_tile_distribution(), i_j_idx);
|
||||
const index_t channel_idx = tile_idx.at(number<1>{});
|
||||
const float v_scale = v_descale_lds[channel_idx];
|
||||
o_acc(i_j_idx) *= v_scale;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// finally, O
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(AttnMask::IsMasking)
|
||||
{
|
||||
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
|
||||
}
|
||||
else
|
||||
return 1 / l[i_idx];
|
||||
}();
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
    /// @brief Convenience overload: runs the pipeline with identity element-wise
    ///        functors for Q/K/V/bias/LSE/S-acc (no per-element pre/post processing).
    ///
    /// Forwards every argument unchanged to the full operator() overload, inserting
    /// ck_tile::identity{} for each element-function slot.
    ///
    /// @param q_dram_block_window_tmp  Q DRAM window (M0*K0 tile)
    /// @param k_dram_block_window_tmp  K DRAM window (N0*K0 tile)
    /// @param v_dram_block_window_tmp  V DRAM window (N1*K1 tile)
    /// @param mask               attention mask functor
    /// @param position_encoding  positional-encoding functor
    /// @param scale_s            softmax scale applied to S = Q*K^T
    /// @param variant / variant_params / block_indices  attention-variant plumbing
    /// @param smem_ptr           base of the workgroup's LDS scratch
    /// @param q_descale_ptr / k_descale_ptr / v_descale_ptr
    ///                           optional quantization de-scale arrays (nullptr = unused)
    /// @param q_descale_value    scalar Q de-scale (used by per-tensor modes;
    ///                           marked maybe_unused since some modes ignore it)
    /// @return the accumulated output tile (same as the full overload)
    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename PositionEncoding,
              typename AttentionVariantParams,
              typename BlockIndices>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
               AttnMask mask,
               PositionEncoding position_encoding,
               float scale_s,
               const AttentionVariant& variant,
               const AttentionVariantParams& variant_params,
               const BlockIndices& block_indices,
               void* smem_ptr,
               const float* q_descale_ptr = nullptr,
               const float* k_descale_ptr = nullptr,
               const float* v_descale_ptr = nullptr,
               [[maybe_unused]] float q_descale_value = 1.0f) const
    {
        return operator()(q_dram_block_window_tmp,
                          identity{},
                          k_dram_block_window_tmp,
                          identity{},
                          v_dram_block_window_tmp,
                          identity{},
                          identity{},
                          identity{},
                          identity{},
                          mask,
                          position_encoding,
                          scale_s,
                          variant,
                          variant_params,
                          block_indices,
                          smem_ptr,
                          q_descale_ptr,
                          k_descale_ptr,
                          v_descale_ptr,
                          q_descale_value);
    }
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,18 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for the async QR-KS-VS pipeline: Q is loaded once into
// registers, K and V are staged through LDS with async copies, and three
// LDS prefetch buffers are used for each of K and V.
using BlockSageAttentionPipelineQRKSVSAsyncDefaultPolicy =
    BlockSageAttnPipelineQRKSVSCustomPolicy</* QLoadOnce = */ true,
                                            /* AsyncCopy = */ true,
                                            /* NumPrefetchK = */ 3,
                                            /* NumPrefetchV = */ 3>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,857 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetPackedSize()
|
||||
{
|
||||
return numeric_traits<remove_cvref_t<T>>::PackedSize;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetLogicalVectorSize(index_t bytes)
|
||||
{
|
||||
return (bytes / sizeof(remove_cvref_t<T>)) * GetPackedSize<T>();
|
||||
}
|
||||
|
||||
// Effective Q element type seen by the Q*K^T gemm: packed (sub-byte) Q types
// are substituted with fp8_t, otherwise the problem's QDataType is used as-is.
// NOTE(review): presumably packed int4 data is expanded to an 8-bit
// representation before the MFMA, hence the fp8_t stand-in — confirm against
// the pipeline's unpack path.
template <typename Problem>
using SageAttnQKGemmQDataType =
    std::conditional_t<is_packed_type_v<remove_cvref_t<typename Problem::QDataType>>,
                       fp8_t,
                       remove_cvref_t<typename Problem::QDataType>>;

// Effective K element type for the Q*K^T gemm; same substitution rule as for Q.
template <typename Problem>
using SageAttnQKGemmKDataType =
    std::conditional_t<is_packed_type_v<remove_cvref_t<typename Problem::KDataType>>,
                       fp8_t,
                       remove_cvref_t<typename Problem::KDataType>>;
|
||||
|
||||
// Primary template (declaration only); specialized below per QLoadOnce value.
// Only the QLoadOnce == true specialization is defined in this file.
template <bool QLoadOnce_>
struct BlockSageAttnPipelineQRCustomPolicy;
|
||||
|
||||
// Policy for pipelines that load Q once into registers and keep it there for
// the whole main loop (no Q staging through LDS).
template <>
struct BlockSageAttnPipelineQRCustomPolicy</* QLoadOnce = */ true>
{
    static constexpr bool QLoadOnce = true;

    // Q lives in registers, so the pipeline reserves no LDS space for it.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
    {
        return 0;
    }

    // TODO: GetAlignment*() currently didn't consider if need padding or not
    // so in pipeline still need check padding requirement

    // Global-load vector size (in logical elements) for the Q tile: the widest
    // 16-byte load, capped by how many contiguous K elements one lane owns in
    // the warp-gemm A-operand layout.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
    {
        constexpr index_t MaxVectorSize = GetLogicalVectorSize<typename Problem::QDataType>(16);

        using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
        constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
    }

    // Register tile distribution for the Q operand (A side of gemm_0), sized
    // M0 x kSubQKHeaddim per the block shape.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
    {
        using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;

        return BlockGemm::template MakeABlockTileDistribution<
            Problem::BlockSageAttnShape::kM0,
            Problem::BlockSageAttnShape::kSubQKHeaddim>();
    }

    // Select the block-gemm implementation for S = Q*K^T (gemm_0).
    // Dispatch order:
    //   1) fp8 x fp8 with float S-acc on wave64 -> hard-coded swizzled
    //      MFMA 32x32x32 warp gemm (TransposedC distribution);
    //   2) int8 x int8 on wave64 -> the int8 analogue of the same warp gemm,
    //      accumulating in int32;
    //   3) everything else -> generic WarpGemmDispatcher.
    // Finally picks a multi-warp or one-warp block gemm based on kNumGemm0Warps.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
    {
        using QKGemmQDataType = SageAttnQKGemmQDataType<Problem>;
        using QKGemmKDataType = SageAttnQKGemmKDataType<Problem>;
        // int8 MFMA accumulates to int32, but SaccDataType is float for softmax
        using GemmAccDataType =
            std::conditional_t<(std::is_same_v<QKGemmQDataType, int8_t> ||
                                std::is_same_v<QKGemmQDataType, signed char>) &&
                                   (std::is_same_v<QKGemmKDataType, int8_t> ||
                                    std::is_same_v<QKGemmKDataType, signed char>),
                               int32_t,
                               typename Problem::SaccDataType>;

        using GemmProblem =
            BlockGemmProblem<QKGemmQDataType,
                             QKGemmKDataType,
                             GemmAccDataType,
                             Problem::kNumGemm0Warps * get_warp_size(),
                             TileGemmShape<sequence<Problem::BlockSageAttnShape::kM0,
                                                    Problem::BlockSageAttnShape::kN0,
                                                    Problem::BlockSageAttnShape::kK0>,
                                           typename Problem::BlockSageAttnShape::Gemm0BlockWarps,
                                           typename Problem::BlockSageAttnShape::Gemm0WarpTile>>;

        constexpr auto warp_gemm = []() {
            if constexpr(get_warp_size() == 64 && std::is_same_v<QKGemmQDataType, fp8_t> &&
                         std::is_same_v<QKGemmKDataType, fp8_t> &&
                         std::is_same_v<typename Problem::SaccDataType, float>)
            {
                // This path only supports the 32x32x32 warp tile.
                static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}) == 32);
                static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<1>{}) == 32);
                static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<2>{}) == 32);

                // TODO: hard coded here. Otherwise, it produces incorrect results
                constexpr index_t swizzle_factor = 4;
                return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
                    swizzle_factor>{};
            }
            else if constexpr(get_warp_size() == 64 &&
                              (std::is_same_v<QKGemmQDataType, int8_t> ||
                               std::is_same_v<QKGemmQDataType, signed char>) &&
                              (std::is_same_v<QKGemmKDataType, int8_t> ||
                               std::is_same_v<QKGemmKDataType, signed char>))
            {
                // int8 path likewise requires the 32x32x32 warp tile.
                static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}) == 32);
                static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<1>{}) == 32);
                static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<2>{}) == 32);

                // Use special int8 MFMA with K iteration (similar to FP8)
                constexpr index_t swizzle_factor = 4;
                return WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<
                    swizzle_factor>{};
            }
            else
            {
                // Generic fallback; swizzle the A operand only for 32-wide M tiles.
                constexpr bool SwizzleA =
                    Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}) == 32;
                return WarpGemmDispatcher<
                    QKGemmQDataType,
                    QKGemmKDataType,
                    GemmAccDataType,
                    Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}),
                    Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<1>{}),
                    Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<2>{}),
                    true, // TransposeC
                    SwizzleA>{};
            }
        }();

        using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
            QKGemmQDataType,
            QKGemmKDataType,
            GemmAccDataType,
            typename Problem::BlockSageAttnShape::Gemm0BlockWarps,
            decltype(warp_gemm)>;

        // One-warp variant avoids cross-warp reduction overhead when only a
        // single warp works on gemm_0.
        if constexpr(1 < Problem::kNumGemm0Warps)
            return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
        else
            return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
    }
};
|
||||
|
||||
// This pipeline is qkv all located in LDS
|
||||
template <bool QLoadOnce_, bool AsyncCopy_, index_t NumPrefetchK_, index_t NumPrefetchV_>
|
||||
struct BlockSageAttnPipelineQRKSVSCustomPolicy : BlockSageAttnPipelineQRCustomPolicy<QLoadOnce_>
|
||||
{
|
||||
    // Whether K is brought into LDS with async (buffer_load_lds-style) copies.
    static constexpr bool AsyncCopy = AsyncCopy_;

    // Number of in-flight LDS prefetch stages for the K and V tiles.
    static constexpr index_t NumPrefetchK = NumPrefetchK_;
    static constexpr index_t NumPrefetchV = NumPrefetchV_;

    // K and V share the same ring of LDS buffers, so the ring must be as deep
    // as the deeper of the two prefetch requirements.
    static constexpr index_t NumKVLdsBuffers = max(NumPrefetchK, NumPrefetchV);

    // Q-side policy (register-resident Q when QLoadOnce_ is true).
    using QXPolicy = BlockSageAttnPipelineQRCustomPolicy<QLoadOnce_>;
|
||||
|
||||
    // Computes, at compile time, which LDS ring buffer each pipeline iteration
    // uses: entries [0, k_loops_) are gemm_0 (K) iterations, entries
    // [k_loops_, k_loops_+v_loops_) are gemm_1 (V) iterations.
    template <index_t k_prefetches_, index_t v_prefetches_, index_t k_loops_, index_t v_loops_>
    struct LdsBufferSequence
    {
        static constexpr index_t num_lds_buffers_ = max(k_prefetches_, v_prefetches_);
        // Largest multiple of num_lds_buffers_ not exceeding v_loops_-1; used to
        // phase-shift the V assignments below.
        static constexpr index_t ceil_ = ((v_loops_ - 1) / num_lds_buffers_) * num_lds_buffers_;

        // for qr_ks_vs_async, the Lds buffer assigned to last gemm_1 iteration of V should not
        // overlap with the Lds buffers used by first two gemm_0 iterations of K
        static constexpr auto Make()
        {
            // ensure v_loop_-1 is assigned to num_lds_buffers-1
            return transform_sequences(
                [&](auto i) {
                    // K iterations simply cycle through the ring.
                    if(i < k_loops_)
                        return i % num_lds_buffers_;
                    else
                        // V iterations cycle with an offset chosen so the final
                        // V iteration lands on buffer num_lds_buffers_-1.
                        return ((num_lds_buffers_ - 1) + (i - k_loops_ + ceil_ - (v_loops_ - 1))) %
                               num_lds_buffers_;
                },
                typename arithmetic_sequence_gen<0, k_loops_ + v_loops_, 1>::type{});
        }

        using type = remove_cvref_t<decltype(Make())>;
    };
|
||||
|
||||
    // Hand-tuned buffer schedules for common (k_loops, v_loops) combinations
    // with 3 K / 3 V prefetch buffers; these override the generic Make()
    // formula above. Each sequence lists the LDS buffer index per iteration
    // (K iterations first, then V iterations).
    // clang-format off
    template<> struct
    LdsBufferSequence<3, 3, 4, 4> { using type = sequence<1, 2, 0, 1, 0, 1, 2, 0>; };

    template<> struct
    LdsBufferSequence<3, 3, 4, 2> { using type = sequence<1, 2, 0, 1, 2, 0>; };

    template<> struct
    LdsBufferSequence<3, 3, 2, 4> { using type = sequence<1, 2, 0, 1, 2, 0>; };

    template<> struct
    LdsBufferSequence<3, 3, 3, 3> { using type = sequence<1, 2, 0, 1, 2, 0>; };

    template<> struct
    LdsBufferSequence<3, 3, 3, 4> { using type = sequence<1, 2, 0, 0, 1, 2, 0>; };

    template<> struct
    LdsBufferSequence<3, 3, 2, 2> { using type = sequence<1, 2, 1, 0>;};
    // clang-format on
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetLdsBufferSequence()
|
||||
{
|
||||
using BlockSageAttnShape = remove_cvref_t<typename Problem::BlockSageAttnShape>;
|
||||
|
||||
constexpr index_t kN0 = BlockSageAttnShape::kN0;
|
||||
constexpr index_t kK0 = BlockSageAttnShape::kK0;
|
||||
constexpr index_t kK1 = BlockSageAttnShape::kK1;
|
||||
constexpr index_t kQKHeaddim = BlockSageAttnShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
return typename LdsBufferSequence<NumPrefetchK, NumPrefetchV, k0_loops, k1_loops>::type{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
|
||||
{
|
||||
// TODO: this is for 3d layout
|
||||
using KDataType = SageAttnQKGemmKDataType<Problem>;
|
||||
return GetLogicalVectorSize<KDataType>(16);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
|
||||
{
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
if constexpr(AsyncCopy)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
constexpr index_t MaxLoadSizeInBytes = 4 * 4; // dwordx4
|
||||
#else
|
||||
constexpr index_t MaxLoadSizeInBytes = 4; // dword
|
||||
#endif
|
||||
|
||||
return GetLogicalVectorSize<KDataType>(MaxLoadSizeInBytes);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
|
||||
|
||||
constexpr index_t MaxVectorSize = GetLogicalVectorSize<KDataType>(16);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
|
||||
return min(MaxVectorSize, ElemPerThread);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
|
||||
{
|
||||
// TODO: this is for 3d layout
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t kMaxVecLoad =
|
||||
min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
|
||||
|
||||
return kMaxVecLoad;
|
||||
}
|
||||
|
||||
    // Global-load vector size (in elements) for the V tile.
    //
    // Col-major V: simply the per-thread pixel count capped at a 16-byte
    // vector. Row-major V additionally enforces a minimum of one dword per
    // load; if the full-width vector would leave fewer than kMinVecLoad loads
    // per thread, the vector is shrunk to total_pixels / kMinVecLoad.
    // NOTE(review): the shrink branch divides by kMinVecLoad rather than
    // clamping to it — this mirrors the analogous FMHA policy; confirm the
    // intent before changing.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
    {
        using VLayout   = remove_cvref_t<typename Problem::BlockSageAttnShape::VLayout>;
        using VDataType = remove_cvref_t<typename Problem::VDataType>;
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
        constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
        // Per-thread share of the tile, and the 16-byte vector cap.
        constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
        constexpr index_t kMaxVecLoad =
            min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));

        if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
        {
            // At least one dword per load.
            constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);

            constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
                                             ? kMaxVecLoad
                                             : (total_pixels / kMinVecLoad);

            return kVecLoad;
        }
        else
        {
            return kMaxVecLoad;
        }
    }
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::ODataType);
|
||||
return min(MaxVectorSize, WG::WarpGemmAttribute::Impl::kCM1PerLane);
|
||||
}
|
||||
|
||||
// Element count of one K/V staging buffer in LDS. K and V are assumed to
// share the same smem area, so the result is the max of the K-tile and
// V-tile footprints (in elements, padding included).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
{
    // this function assume K/V can share smem
    constexpr index_t SingleKSize = [&]() {
        if constexpr(!AsyncCopy)
        {
            // non-async path: size comes straight from the K LDS descriptor
            return MakeKLdsBlockDescriptor<Problem>().get_element_space_size();
        }
        else
        {
            // async path: mirrors the (NumIssues, LaneGroups, NumWarps,
            // LanesPerK, KVector) decomposition of
            // MakeKLdsStoreBlockDescriptor — keep the two in sync
            constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
            constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
            constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
            constexpr index_t WarpSize = ck_tile::get_warp_size();

            constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
            constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
            constexpr index_t kPad = KPack; // padding inserted after each warp's slice

            static_assert(WarpSize * KVector >= kKPerBlock &&
                          WarpSize * KVector % kKPerBlock == 0);
            constexpr index_t LanesPerK = kKPerBlock / KVector; // lanes covering one K row
            constexpr index_t LaneGroups = WarpSize / LanesPerK; // N rows loaded per warp per issue
            constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);

            return NumIssues * NumWarps * (WarpSize * KVector + kPad);
        }
    }();

    constexpr index_t SingleVSize = [&]() {
        // V footprint mirrors MakeVLdsBlockDescriptor: each LDS "row" spans
        // all banks (PixelsPerRow elements) plus kKPack elements of padding
        using VDataType = remove_cvref_t<typename Problem::VDataType>;
        constexpr index_t Banks = get_n_lds_banks();
        constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
        constexpr index_t kKPack = GetSmemKPackV<Problem>();
        static_assert(PixelsPerRow % kKPack == 0);
        constexpr index_t NPerRow = PixelsPerRow / kKPack;
        constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
        constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
        static_assert(kNPerBlock % NPerRow == 0);
        static_assert(kKPerBlock % kKPack == 0);

        return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
    }();

    return max(SingleKSize, SingleVSize);
}
|
||||
|
||||
// TODO: this is used for non async copy desc. unify in the future
// K-tile LDS descriptor for the non-async pipeline.
// Physical layout: (kKPerBlock/kKPack, kNPerBlock, kKPack) with a stride of
// (kNPerBlock + 1) * kKPack per K-chunk — the extra row of kKPack elements is
// presumably padding to reduce LDS bank conflicts (verify against the gemm's
// read pattern). The returned logical view is 2D: [kNPerBlock, kKPerBlock].
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
    constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
    constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
    constexpr index_t kKPack = GetSmemKPackK<Problem>();

    constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
        make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
        make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
        number<kKPack>{},
        number<1>{});

    // merge (k0, k1) back into a single K axis; N passes through unchanged
    constexpr auto k_lds_block_desc = transform_tensor_descriptor(
        k_lds_block_desc_0,
        make_tuple(
            make_pass_through_transform(number<kNPerBlock>{}),
            make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
        make_tuple(sequence<1>{}, sequence<0, 2>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));

    return k_lds_block_desc;
}
|
||||
|
||||
// K-tile LDS *store* descriptor for the async-copy pipeline, addressing
// staging buffer IBuf (offset IBuf * GetSingleSmemElementSpaceSize()).
// The tile is decomposed as (NumIssues, LaneGroups, NumWarps, LanesPerK,
// KVector): each warp performs NumIssues async copies, each lane writing
// KVector contiguous elements, with kPad elements of padding after every
// warp's (WarpSize * KVector) slice. The returned view is
// (issues, warps, lanes), the shape async buffer-copy expects.
template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{})
{
    // K is always k-major, we use async-copy to load into LDS
    constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
    constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
    constexpr index_t WarpSize = ck_tile::get_warp_size();

    constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
    constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
    constexpr index_t kPad =
        KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed

    static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
    constexpr index_t LanesPerK =
        kKPerBlock / KVector; // how many lane (within a wave) to load K
    constexpr index_t LaneGroups =
        WarpSize /
        LanesPerK; // how many groups (within a wave), they may load different N, but same K
    constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
    // sanity: issue count must equal total elements / (block-wide vector width)
    static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));

    constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
        make_tuple(number<NumIssues>{}, // n0
                   number<LaneGroups>{}, // n1
                   number<NumWarps>{}, // n2
                   number<LanesPerK>{}, // k0
                   number<KVector>{}), // k1
        make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
                   number<kKPerBlock>{},
                   number<WarpSize * KVector + kPad>{},
                   number<KVector>{},
                   number<1>{}),
        number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
        number<KVector>{},
        number<1>{});

    // TODO this layout is hard coded, and will be used in async copy buffer view load
    // in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
    constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
        k_lds_block_desc_0,
        make_tuple(make_pass_through_transform(number<NumIssues>{}),
                   make_pass_through_transform(number<NumWarps>{}),
                   make_merge_transform(make_tuple(
                       number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
        make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
        make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

    return k_lds_block_desc_issues_warps_lanes;
}
|
||||
|
||||
// K-tile LDS *load* descriptor for the async-copy pipeline: one logical
// (N, K) view spanning all NumKVLdsBuffers staging buffers, using the same
// strides and kPad as the store descriptor above so reads land on what the
// async copies wrote — keep in sync with MakeKLdsStoreBlockDescriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor()
{
    // K is always k-major, we use async-copy to load into LDS
    constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
    constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
    constexpr index_t WarpSize = ck_tile::get_warp_size();

    constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
    constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
    constexpr index_t kPad = KPack; // for async-copy, this pad is between warps

    static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
    constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
    constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
    constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
    static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
    // constexpr index_t SingleKSize = NumIssues * NumWarps * (WarpSize * KVector + kPad);
    // constexpr index_t SingleVSize =
    // MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
    // stride between consecutive K/V staging buffers (K and V share the area)
    constexpr index_t BufferSize =
        GetSingleSmemElementSpaceSize<Problem>(); // max(SingleKSize, SingleVSize);

    constexpr auto k_lds_block_desc_0 =
        make_naive_tensor_descriptor(make_tuple(number<NumKVLdsBuffers>{}, // num_buffers
                                                number<NumIssues>{}, // n0
                                                number<NumWarps>{}, // n2
                                                number<LaneGroups>{}, // n1
                                                number<kKPerBlock / KPack>{}, // k0
                                                number<KPack>{}), // k1
                                     make_tuple(number<BufferSize>{},
                                                number<NumWarps*(WarpSize * KVector + kPad)>{},
                                                number<WarpSize * KVector + kPad>{},
                                                number<kKPerBlock>{},
                                                number<KPack>{},
                                                number<1>{}),
                                     number<KPack>{},
                                     number<1>{});

    // merge (bufs, n0, n1, n2) -> N and (k0, k1) -> K into a 2D view
    constexpr auto k_lds_block_desc = transform_tensor_descriptor(
        k_lds_block_desc_0,
        make_tuple(
            make_merge_transform(make_tuple(number<NumKVLdsBuffers>{},
                                            number<NumIssues>{},
                                            number<LaneGroups>{},
                                            number<NumWarps>{})),
            make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
        make_tuple(sequence<0, 1, 3, 2>{}, sequence<4, 5>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));

    return k_lds_block_desc;
}
|
||||
|
||||
// 3d + padding
// V-tile LDS descriptor: one logical (N, K) view spanning all
// NumKVLdsBuffers staging buffers. Each physical LDS row holds PixelsPerRow
// elements (sized to span all LDS banks) plus kKPack elements of padding —
// presumably to avoid bank conflicts when the gemm reads kKPack-wide vectors.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
{
    using VDataType = remove_cvref_t<typename Problem::VDataType>;
    constexpr index_t Banks = get_n_lds_banks();
    constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); // elements per full-bank row
    constexpr index_t kKPack = GetSmemKPackV<Problem>();
    static_assert(PixelsPerRow % kKPack == 0);
    constexpr index_t NPerRow = PixelsPerRow / kKPack; // N entries packed into one row
    constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
    constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
    static_assert(kNPerBlock % NPerRow == 0);
    static_assert(kKPerBlock % kKPack == 0);

    constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
        make_tuple(number<NumKVLdsBuffers>{},
                   number<kKPerBlock / kKPack>{},
                   number<kNPerBlock / NPerRow>{},
                   number<NPerRow>{},
                   number<kKPack>{}),
        make_tuple(number<GetSingleSmemElementSpaceSize<Problem>()>{},
                   number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
                   number<PixelsPerRow + kKPack>{},
                   number<kKPack>{},
                   number<1>{}),
        number<kKPack>{},
        number<1>{});

    // merge (bufs, n_row, n_in_row) -> N and (k0, k1) -> K into a 2D view
    constexpr auto v_lds_block_desc = transform_tensor_descriptor(
        v_lds_block_desc_0,
        make_tuple(
            make_merge_transform(make_tuple(
                number<NumKVLdsBuffers>{}, number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
            make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
        make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));

    return v_lds_block_desc;
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
|
||||
{
|
||||
// TODO: assume Q is in register
|
||||
// TODO: assume K and V share smem buffers
|
||||
using KLdsDataType = SageAttnQKGemmKDataType<Problem>;
|
||||
constexpr index_t single_smem_size =
|
||||
GetSingleSmemElementSpaceSize<Problem>() * sizeof(KLdsDataType);
|
||||
|
||||
return QXPolicy::template GetSmemSizeQ<Problem>() + single_smem_size * NumKVLdsBuffers;
|
||||
}
|
||||
|
||||
// Total LDS footprint of the pipeline. Everything currently lives in the
// shared Q + K/V area, so this simply forwards to GetSmemSizeKV().
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    return GetSmemSizeKV<Problem>();
}
|
||||
|
||||
// Tile distribution for loading the (k-major) K tile from global memory.
// Non-async path: vectorize along K and spread N across repeats / warps /
// lanes. Async path: use exactly the (NumIssues, LaneGroups, NumWarps,
// LanesPerK, KVector) decomposition of MakeKLdsStoreBlockDescriptor so every
// async-copy issue lands in its LDS slot — keep the two in sync.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
{
    if constexpr(!AsyncCopy)
    {
        using KDataType = remove_cvref_t<typename Problem::KDataType>;

        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
        constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;

        constexpr index_t MaxVectorSize = GetLogicalVectorSize<KDataType>(16);
        constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;

        constexpr index_t K1 = min(MaxVectorSize, ElemPerThread); // per-thread vector length
        constexpr index_t K0 = kKPerBlock / K1; // lanes along K
        constexpr index_t N2 = get_warp_size() / K0; // N rows per warp
        constexpr index_t N1 = kBlockSize / get_warp_size(); // warps per block
        constexpr index_t N0 = kNPerBlock / (N2 * N1); // per-thread N repeats

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }
    else
    {
        constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
        constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
        constexpr index_t WarpSize = ck_tile::get_warp_size();

        constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load

        static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
        constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
        constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
        constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
        static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));

        // name the dimensions after the LDS store layout they must match
        constexpr index_t N0 = NumIssues;
        constexpr index_t N1 = LaneGroups;
        constexpr index_t N2 = NumWarps;
        constexpr index_t K0 = LanesPerK;
        constexpr index_t K1 = KVector;

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<2>, sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }
}
|
||||
|
||||
// Tile distribution for loading the V tile from global memory.
// Row-major V (seqlen * hdim): hdim (N) is the contiguous axis, so loads are
// vectorized along N; the data is presumably re-packed to k-major in
// registers afterwards (see MakeShuffledVRegBlockDescriptor, whose branch
// structure mirrors this one — keep them in sync). Column-major V
// (hdim * seqlen): seqlen (K) is contiguous, so vectorize along K directly.
// Each branch has a fallback encoding for shapes where the preferred
// decomposition is not divisible.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
{
    using VLayout = remove_cvref_t<typename Problem::BlockSageAttnShape::VLayout>;

    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
    constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;

    if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
    {
        constexpr index_t N1 = GetAlignmentV<Problem>();
        constexpr index_t N0 = kNPerBlock / N1; // P

        constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
        constexpr index_t kKPack = GetSmemKPackV<Problem>();
        constexpr index_t K3 = total_pixels / N1;
        constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
        if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
        {
            // fallback: regroup N into 16/32-wide packs and rebuild the
            // decomposition from the warp shape
            static_assert(kNPerBlock % 16 == 0);
            constexpr index_t kNPack = kNPerBlock % 32 == 0 ? 32 : 16;
            constexpr index_t K0 = kBlockSize / get_warp_size();
            constexpr index_t N2 = 2;
            constexpr index_t N1_m = kNPack / N2;
            constexpr index_t N0_m = kNPerBlock / kNPack;
            constexpr index_t K1 = get_warp_size() / N1_m;
            constexpr index_t K2_m = kKPerBlock / K1 / K0;
            return make_static_tile_distribution(
                tile_distribution_encoding<
                    sequence<1>,
                    tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
                    tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
                    tuple<sequence<0>, sequence<1, 1>>,
                    sequence<1, 2, 1>, // N0 K2 N2
                    sequence<0, 2, 2>>{});
        }
        else if constexpr(get_warp_size() % (K2 * N0) == 0)
        {
            // preferred path: one wave covers (K2, N0); K1 warps-in-wave groups
            constexpr index_t K1 = get_warp_size() / (K2 * N0);
            constexpr index_t K0 = kBlockSize / get_warp_size();
            static_assert(kKPerBlock == K0 * K1 * K2 * K3);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
                                           tuple<sequence<2>, sequence<2, 1, 2>>,
                                           tuple<sequence<0>, sequence<1, 0, 2>>,
                                           sequence<2, 1>,
                                           sequence<3, 1>>{});
        }
        else
        {
            // (K2 * N0) exceeds a wave: split K2 across waves instead
            constexpr index_t K1 = (K2 * N0) / get_warp_size();
            constexpr index_t K2_m = K2 / K1;
            constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
            static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
                                           tuple<sequence<2, 2>, sequence<1, 2>>,
                                           tuple<sequence<0, 1>, sequence<0, 2>>,
                                           sequence<2, 1>,
                                           sequence<3, 1>>{});
        }
    }
    else
    {
        // column-major V: vectorize along the contiguous seqlen (K) axis
        constexpr index_t K1 = GetAlignmentV<Problem>();
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t N2 = get_warp_size() / K0;
        constexpr index_t N1 = kBlockSize / get_warp_size();
        static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
        static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
        constexpr index_t N0 = kNPerBlock / (N2 * N1);
        static_assert(N0 != 0);

        constexpr auto dstr = make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K0
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>, // N0 K1
                                       sequence<0, 1>>{});
        // accept the simple encoding only if it covers the whole tile
        if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
                     kNPerBlock * kKPerBlock)
        {
            return dstr;
        }
        else
        {
            // fallback: regroup K into 16/32-wide iterations
            static_assert(kKPerBlock % 16 == 0);
            constexpr index_t kKPerIter = kKPerBlock % 32 == 0 ? 32 : 16;
            constexpr index_t K0_m = kKPerBlock / kKPerIter;
            constexpr index_t K2 = 2;
            constexpr index_t K1_m = kKPerIter / K2;
            constexpr index_t N2_m = get_warp_size() / K1_m;
            constexpr index_t N0_m = kNPerBlock / (N2_m * N1);
            constexpr auto dstr_m = make_static_tile_distribution(
                tile_distribution_encoding<
                    sequence<1>,
                    tuple<sequence<N0_m, N1, N2_m>, sequence<K0_m, K1_m, K2>>,
                    tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K1
                    tuple<sequence<1>, sequence<2, 1>>,
                    sequence<2, 1, 2>, // K0 N0 K2
                    sequence<0, 0, 2>>{});
            static_assert(container_reduce(dstr_m.get_lengths(),
                                           std::multiplies<index_t>{},
                                           1) == kNPerBlock * kKPerBlock);
            return dstr_m;
        }
    }
}
|
||||
|
||||
// Register-tile distribution for the shuffled (k-major) arrangement of a
// row-major V tile. The three branches mirror MakeVDramTileDistribution's
// row-major branches one-for-one (same N0/N1/K0..K3 math, different
// dimension-to-lane mapping) — any change there must be reflected here.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegBlockDescriptor()
{
    // This descriptor only used when V layout is seqlen * hdim
    using VLayout = remove_cvref_t<typename Problem::BlockSageAttnShape::VLayout>;
    static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
    constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;

    constexpr index_t N1 = GetAlignmentV<Problem>();
    constexpr index_t N0 = kNPerBlock / N1;
    constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
    constexpr index_t K3 = total_pixels / N1;
    constexpr index_t kKPack = GetSmemKPackV<Problem>();
    constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
    if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
    {
        // fallback decomposition (see MakeVDramTileDistribution's first branch)
        static_assert(kNPerBlock % 16 == 0);
        constexpr index_t kNPack = kNPerBlock % 32 == 0 ? 32 : 16;
        constexpr index_t K0 = kBlockSize / get_warp_size();
        constexpr index_t N2 = 2;
        constexpr index_t N1_m = kNPack / N2;
        constexpr index_t N0_m = kNPerBlock / kNPack;
        constexpr index_t K1 = get_warp_size() / N1_m;
        constexpr index_t K2_m = kKPerBlock / K1 / K0;
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
                                       tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
                                       tuple<sequence<0>, sequence<1, 1>>,
                                       sequence<1, 1, 2>, // N0 K2 <-> N2
                                       sequence<0, 2, 2>>{});
    }
    else if constexpr(get_warp_size() % (K2 * N0) == 0)
    {
        // preferred path; per-thread dims swapped to N-then-K vs the dram view
        constexpr index_t K1 = get_warp_size() / (K2 * N0);
        constexpr index_t K0 = kBlockSize / get_warp_size();

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
                                       tuple<sequence<2>, sequence<2, 1, 2>>,
                                       tuple<sequence<0>, sequence<1, 0, 2>>,
                                       sequence<1, 2>,
                                       sequence<1, 3>>{});
    }
    else
    {
        // (K2 * N0) exceeds a wave: split K2 across waves
        constexpr index_t K1 = (K2 * N0) / get_warp_size();
        constexpr index_t K2_m = K2 / K1;
        constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
        static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
                                       tuple<sequence<2, 2>, sequence<1, 2>>,
                                       tuple<sequence<0, 1>, sequence<0, 2>>,
                                       sequence<1, 2>,
                                       sequence<1, 3>>{});
    }
}
|
||||
|
||||
// Build the block gemm for the second gemm of attention (P * V -> O
// accumulator): A (P) in registers, B (V) in smem, C in registers.
// On wave64 with fp8 P/V and float accumulation a hand-picked 32x32x32
// swizzled-B warp gemm is used; otherwise the warp gemm comes from the
// generic WarpGemmDispatcher (the trailing 'true' presumably selects the
// transposed-C distribution — confirm against WarpGemmDispatcher).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{
    using GemmProblem =
        BlockGemmProblem<typename Problem::PDataType,
                         typename Problem::VDataType,
                         typename Problem::OaccDataType,
                         Problem::kNumGemm1Warps * get_warp_size(),
                         TileGemmShape<sequence<Problem::BlockSageAttnShape::kM0,
                                                Problem::BlockSageAttnShape::kN1,
                                                Problem::BlockSageAttnShape::kK1>,
                                       typename Problem::BlockSageAttnShape::Gemm1BlockWarps,
                                       typename Problem::BlockSageAttnShape::Gemm1WarpTile>>;

    auto warp_gemm = [&]() {
        if constexpr(get_warp_size() == 64 &&
                     std::is_same_v<typename Problem::PDataType, fp8_t> &&
                     std::is_same_v<typename Problem::VDataType, fp8_t> &&
                     std::is_same_v<typename Problem::OaccDataType, float>)
        {
            // the specialized warp gemm only exists for a 32x32x32 warp tile
            static_assert(Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<0>{}) == 32);
            static_assert(Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<1>{}) == 32);
            static_assert(Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<2>{}) == 32);

            return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<>{};
        }
        else
        {
            return WarpGemmDispatcher<
                typename Problem::PDataType,
                typename Problem::VDataType,
                typename Problem::OaccDataType,
                Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<0>{}),
                Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<1>{}),
                Problem::BlockSageAttnShape::Gemm1WarpTile::at(number<2>{}),
                true>{};
        }
    }();

    using WarpGemm = remove_cvref_t<decltype(warp_gemm)>;

    using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
        typename Problem::PDataType,
        typename Problem::VDataType,
        typename Problem::OaccDataType,
        typename Problem::BlockSageAttnShape::Gemm1BlockWarps,
        WarpGemm>;
    return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for the synchronous QRKSVS sageattention forward pipeline:
// per its template arguments, Q is loaded once (QLoadOnce = true), plain
// (non-async) copies are used, and one prefetch buffer each for K and V.
using BlockSageAttentionPipelineQRKSVSDefaultPolicy =
    BlockSageAttnPipelineQRKSVSCustomPolicy</* QLoadOnce = */ true,
                                            /* AsyncCopy = */ false,
                                            /* NumPrefetchK = */ 1,
                                            /* NumPrefetchV = */ 1>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,71 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t Headdim>
|
||||
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length()
|
||||
{
|
||||
if constexpr(Headdim == 48)
|
||||
return 48;
|
||||
else if constexpr(Headdim == 80)
|
||||
return 96;
|
||||
else if constexpr(Headdim == 96)
|
||||
return 128;
|
||||
else if constexpr(Headdim == 160)
|
||||
return 256;
|
||||
else if constexpr(Headdim == 192)
|
||||
return 192;
|
||||
else if constexpr(is_power_of_two_integer(Headdim))
|
||||
return Headdim;
|
||||
else
|
||||
static_assert(Headdim == 0,
|
||||
"only Headdim of 48, 96, 160, 192 and power-of-two is supported");
|
||||
};
|
||||
|
||||
// Compile-time tile-shape descriptor for the sageattention forward kernel.
// BlockTile_ packs the per-block tile sizes as
//   sequence<kM0, kN0, kK0, kN1, kK1, kQKHeaddim>,
// and the Gemm{0,1}BlockWarps / Gemm{0,1}WarpTile sequences describe the warp
// grid and warp-gemm tile of the QK gemm (gemm0) and the KV gemm (gemm1).
template <typename BlockTile_, // sequence<...
          typename Gemm0BlockWarps_,
          typename Gemm0WarpTile_,
          typename Gemm1BlockWarps_,
          typename Gemm1WarpTile_,
          bool IsVLayoutRowMajor_>
struct TileSageAttnShape
{
    using BlockTile = remove_cvref_t<BlockTile_>;
    using Gemm0BlockWarps = remove_cvref_t<Gemm0BlockWarps_>;
    using Gemm0WarpTile = remove_cvref_t<Gemm0WarpTile_>;
    using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
    using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;

    // warps cooperating in each gemm = product of that gemm's warp-grid dims
    static constexpr index_t NumGemm0Warps =
        reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
    static constexpr index_t NumGemm1Warps =
        reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{});
    static_assert(NumGemm1Warps % NumGemm0Warps == 0);

    // warps launched per block: enough for the larger of the two gemms
    static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);

    static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
    static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
    static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
    static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
    static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
    static constexpr index_t kQKHeaddim =
        BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
                                    // once (or repeately load Q as a whole tile)
    static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");

    // qk head-dim rounded up to a supported tile length
    // (see ceil_to_qualified_tile_length for the mapping)
    static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length<kQKHeaddim>();

    // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
    static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
    using VLayout = std::conditional_t<IsVLayoutRowMajor,
                                       ck_tile::tensor_layout::gemm::RowMajor,
                                       ck_tile::tensor_layout::gemm::ColumnMajor>;
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
|
||||
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Compile-time traits for the sageattention forward kernel: per-dimension
// padding switches, the quantization-scale granularity, and tuning knobs
// (occupancy override, chunked-prefill min-seqlen skip).
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
          bool kPadSeqLenK_ /* padding for seqlen_k */,
          bool kPadHeadDimQ_ /* padding for hdim_q */,
          bool kPadHeadDimV_ /* padding for hdim_v */,
          BlockSageAttentionQuantScaleEnum QScaleEnum_,
          index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
          bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
struct TileSageAttnTraits
{
    static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
    static constexpr bool kPadSeqLenK = kPadSeqLenK_;
    static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
    static constexpr bool kPadHeadDimV = kPadHeadDimV_;
    static constexpr auto QScaleEnum = QScaleEnum_; // Q/K quant-scale granularity
    static constexpr index_t kBlockPerCu = kBlockPerCu_; // -1 keeps the default occupancy
    static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;

    /// Tokens per Q/K descale along seqlen. Fine-to-coarse: PERTHREAD, PERWARP, then 128 for Q
    /// (BLOCKSCALE / no_scale / pertensor). K: PERWARP 64, BLOCKSCALE 128, else 128.
    static constexpr index_t kBlockScaleSizeQ =
        QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERTHREAD ? 4
        : QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERWARP ? 32
                                                                   : 128;
    static constexpr index_t kBlockScaleSizeK =
        QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERTHREAD ? 16
        : QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERWARP ? 64
                                                                   : 128;
};
|
||||
|
||||
} // namespace ck_tile
|
||||
17
include/ck_tile/ops/sageattn.hpp
Normal file
17
include/ck_tile/ops/sageattn.hpp
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/sageattention/kernel/sageattn_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/tile_sageattn_shape.hpp"
|
||||
#include "ck_tile/ops/sageattention/pipeline/tile_sageattn_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
Reference in New Issue
Block a user