[rocm-libraries] ROCm/rocm-libraries#6574 (commit b3db057)

[CK_TILE] Add SageAttention v2 forward kernel with
 multi-granularity quantization (#6574)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

Add a CK_TILE forward kernel implementing [SageAttention
v2](https://arxiv.org/abs/2411.10958) — an attention algorithm that
applies multi-granularity quantization to Q/K/V before computing
attention, trading minimal accuracy loss for higher throughput on
low-precision hardware.

### Quantization design

| Tensor | Supported data types | Scale granularity options |
|--------|---------------------|--------------------------|
| Q | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (32 tokens), per-thread (4 tokens) |
| K | fp8 / int8 / int4 | per-tensor, per-block (128 tokens), per-warp (64 tokens), per-thread (16 tokens) |
| V | fp8 | per-channel (always) |
| O | bf16 | — |

Three precision combinations are supported: `fp8/bf16` (QKV fp8, O
bf16), `i8/fp8/bf16` (QK int8, V fp8, O bf16), and `i4/fp8/bf16` (QK
int4, V fp8, O bf16).

### Architecture support

- **gfx9** (CDNA2/3, e.g. gfx90a, gfx942) — full tile set
- **gfx950** (CDNA4) — restricted tile set (N-per-block capped at 64 for
fp8-family dtypes)

### Implementation

- Two pipeline variants: `QRKSVS` (synchronous) and `QRKSVS_ASYNC`
(async copy)
- Masking support: no mask, causal (top-left / bottom-right), and
generic windowed
- Batch and group (variable-length) modes
- Head dimension: d=128, d_v=128
- Python codegen under `example/ck_tile/49_sageattention/codegen/`
generates kernel instances per target/dtype/tile combination
- Smoke tests included via `tile_example_sageattn_fwd`

### Test commands

```bash
# fp8 QKV
./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=fp8bf16 -qscale=3 -init=3

# int8 QK, fp8 V
./build/bin/tile_example_sageattn_fwd -v=1 -b=16 -h=8 -s=1024 -d=128 -kname=1 -prec=i8fp8bf16 -qscale=3 -init=3
```

`-qscale` values: 1=per-tensor, 2=per-block, 3=per-warp, 4=per-thread
This commit is contained in:
ltqin
2026-04-30 18:33:36 +00:00
committed by assistant-librarian[bot]
parent e8d64ad5c6
commit de0a61e5c2
30 changed files with 7809 additions and 0 deletions

View File

@@ -530,4 +530,10 @@ using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
// Warp-level GEMM alias: int8 A/B operands accumulated into int32 via the
// MFMA 32x32x16 i8 instruction, with a swizzled B operand and a transposed C
// distribution. The fixed `2` is presumably the K-iteration count
// (2 x 16 = K32, matching the K32 in the alias name) -- TODO(review): confirm
// against WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB's
// parameter list. `swizzle_factor` tunes the B-operand swizzle (default 2).
template <index_t swizzle_factor = 2>
using WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>,
2,
swizzle_factor>>;
} // namespace ck_tile

View File

@@ -0,0 +1,49 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
namespace ck_tile {
// This class is used for codegen pattern matching
// Granularity of the quantization scales applied to Q/K.
// The numeric values are stable: codegen pattern matching relies on them.
enum class BlockSageAttentionQuantScaleEnum
{
    NO_SCALE   = 0, // no dequantization (fp16/bf16 configs)
    PERTENSOR  = 1, // one scale for the whole tensor
    BLOCKSCALE = 2, // one scale per block of tokens
    PERWARP    = 3, // one scale per warp-sized group of tokens
    PERTHREAD  = 4, // one scale per thread-sized group of tokens
};

// Maps a scale enum value to the short tag used when naming kernel instances.
// Only the five specializations below exist; the primary template is left
// undefined so an unsupported value fails to compile.
template <BlockSageAttentionQuantScaleEnum Scale>
struct BlockSageAttentionQuantScaleEnumToStr;

template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::NO_SCALE>
{
    static constexpr const char* name = "";
};

template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::PERTENSOR>
{
    static constexpr const char* name = "pertensor";
};

template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::BLOCKSCALE>
{
    static constexpr const char* name = "blockscale";
};

template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::PERWARP>
{
    static constexpr const char* name = "perwarp";
};

template <>
struct BlockSageAttentionQuantScaleEnumToStr<BlockSageAttentionQuantScaleEnum::PERTHREAD>
{
    static constexpr const char* name = "perthread";
};
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,29 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile {
// This class is used for codegen pattern matching
// Which pipeline variant a kernel instance uses; codegen pattern matching
// dispatches on this tag.
enum class BlockSageAttnPipelineEnum
{
    QRKSVS = 0,   // synchronous variant
    QRKSVS_ASYNC, // async-copy variant
};

// Short name suffix for each pipeline variant, used in kernel naming.
// The primary template is left undefined so an unsupported value fails to
// compile.
template <BlockSageAttnPipelineEnum Pipeline>
struct BlockSageAttnPipelineEnumToStr;

template <>
struct BlockSageAttnPipelineEnumToStr<BlockSageAttnPipelineEnum::QRKSVS>
{
    static constexpr const char* name = "qr";
};

template <>
struct BlockSageAttnPipelineEnumToStr<BlockSageAttnPipelineEnum::QRKSVS_ASYNC>
{
    static constexpr const char* name = "qr_async";
};
} // namespace ck_tile

View File

@@ -0,0 +1,60 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
namespace ck_tile {
/// Compile-time problem descriptor consumed by the Sage attention forward
/// pipelines. It bundles every element type in the computation, the block/tile
/// shape, the batching mode, the attention variant/mask types and the trait
/// knobs, so a pipeline can be fully specialized from this single type.
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
typename SaccDataType_,
typename SMPLComputeDataType_,
typename PDataType_,
typename OaccDataType_,
typename ODataType_,
typename BlockSageAttnShape_,
bool kIsGroupMode_,
typename AttentionVariant_,
typename AttnMask_,
typename Traits_>
struct BlockSageAttnPipelineProblem
{
// Decayed element types: Q/K/V inputs, S accumulator (QK-gemm output),
// softmax compute type, P (softmax output / PV-gemm input), O accumulator
// and the final O output.
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
using SaccDataType = remove_cvref_t<SaccDataType_>;
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
using PDataType = remove_cvref_t<PDataType_>;
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using BlockSageAttnShape = remove_cvref_t<BlockSageAttnShape_>;
using AttentionVariant = remove_cvref_t<AttentionVariant_>;
using AttnMask = remove_cvref_t<AttnMask_>;
using Traits = remove_cvref_t<Traits_>;
// Warp counts assigned to gemm0 (QK) and gemm1 (PV), taken from the shape.
static constexpr index_t kNumGemm0Warps = BlockSageAttnShape::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = BlockSageAttnShape::NumGemm1Warps;
// Workgroup size in threads: total warps times the wavefront size.
static constexpr index_t kBlockSize = BlockSageAttnShape::NumWarps * get_warp_size();
// Group mode = variable-length sequences per batch entry.
static constexpr bool kIsGroupMode = kIsGroupMode_;
// attributes from traits
// Padding flags: when set, the pipeline guards accesses along the
// corresponding dimension (seqlen / head dim not a multiple of the tile).
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
// Forwarded from Traits; presumably allows skipping work for short query
// sequences -- TODO(review): confirm against the kernel that consumes it.
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
// Quantization scale granularity (a BlockSageAttentionQuantScaleEnum value).
static constexpr auto QScaleEnum = Traits::QScaleEnum;
// Occupancy hint (blocks per CU); -1 lets the pipeline pick its own default.
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
/// Must match host scale tensor layout (same values as TileSageAttnTraits for Sage kernels).
static constexpr index_t kBlockScaleSizeQ = Traits::kBlockScaleSizeQ;
static constexpr index_t kBlockScaleSizeK = Traits::kBlockScaleSizeK;
};
} // namespace ck_tile

View File

@@ -0,0 +1,861 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// This pipeline is qkv all located in LDS
template <typename Problem_, typename Policy_ = BlockSageAttentionPipelineQRKSVSDefaultPolicy>
struct BlockSageAttentionPipelineQRKSVS
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using QGemmDataType = SageAttnQKGemmQDataType<Problem>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using KLdsDataType = SageAttnQKGemmKDataType<Problem>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
// fp16/bf16 example configs use P=V=fp16/bf16 (qscale=no). Quantized Sage paths use fp8 P/V;
// FP8 softmax shift, v_descale, and PV-gemm LDS layout assume fp8_t for those cases.
static_assert(std::is_same_v<PDataType, VDataType>,
"SageAttention pipeline requires PDataType == VDataType for the PV gemm");
static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
std::is_same_v<PDataType, fp8_t>,
"SageAttention pipeline requires PDataType = fp8_t");
static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
std::is_same_v<VDataType, fp8_t>,
"SageAttention pipeline requires VDataType = fp8_t");
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using AttnMask = remove_cvref_t<typename Problem::AttnMask>;
using BlockSageAttnShape = remove_cvref_t<typename Problem::BlockSageAttnShape>;
using VLayout = remove_cvref_t<typename BlockSageAttnShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockSageAttnShape::kM0;
static constexpr index_t kN0 = BlockSageAttnShape::kN0;
static constexpr index_t kK0 = BlockSageAttnShape::kK0;
static constexpr index_t kN1 = BlockSageAttnShape::kN1;
static constexpr index_t kK1 = BlockSageAttnShape::kK1;
static constexpr index_t kQKHeaddim = BlockSageAttnShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockSageAttnShape::kSubQKHeaddim;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr uint32_t DS_READ = 0x100; // Barrier for DS (data share) read
static constexpr uint32_t MFMA = 0x008; // Barrier for MFMA (matrix multiply-accumulate)
// FP8 softmax shift constants to map softmax output into representable FP8 range
// OCP E4M3 FP8: max exponent = 8, max value ~240 (2^8 * 1.875)
// Use shift=8.0 so exp2(s - m - 8) maps softmax to [0, 2^8] range
// FNUZ E4M3 FP8: max exponent = 7, max value ~120 (2^7 * 1.875)
// Use shift=7.0 so exp2(s - m - 7) maps softmax to [0, 2^7] range
static constexpr float OCP_FP8_SHIFT = 8.0f;
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
static constexpr auto QScaleEnum = Problem::QScaleEnum;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kQKHeaddim <= 32)
{
return 2;
}
else if constexpr(kQKHeaddim <= 64)
{
return 3;
}
else if constexpr(kQKHeaddim <= 128)
{
return 2;
}
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
else
{
return 1;
}
}
}();
static constexpr const char* name = "qr";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
AttnMask mask,
PositionEncoding /*position_encoding*/,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
[[maybe_unused]] const float* q_descale_ptr = nullptr,
const float* k_descale_ptr = nullptr,
const float* v_descale_ptr = nullptr,
[[maybe_unused]] float q_descale_value = 1.0f) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// K tile in LDS
KLdsDataType* k_lds_ptr = static_cast<KLdsDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window_reg =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>());
auto q = load_tile(q_dram_window_reg);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType =
std::conditional_t<std::is_same_v<typename SaccBlockTileType::DataType, SaccDataType>,
SaccBlockTileType,
decltype(cast_tile<SaccDataType>(SaccBlockTileType{}))>;
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
clear_tile(o_acc);
{
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
}
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
const auto tile_range_result = [&mask, &q_origin]() {
auto [start, end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
return ck_tile::make_tuple(start, end);
}();
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<0>{});
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<1>{});
const auto kv_load_start = seqlen_k_start > 0 ? seqlen_k_start : 0;
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if no work to do
if constexpr(AttnMask::IsMasking || kPadSeqLenK)
{
if(num_total_loop <= 0)
{
// Note: here occ are all cleard, return it
// Note: q loaded but no fence, ignore it.
return o_acc;
}
}
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{kv_load_start, 0});
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, kv_load_start},
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = [&]() {
if constexpr(std::is_same_v<QDataType, QGemmDataType>)
return tile_elementwise_in(q_element_func, q);
else
{
auto q_tile_tmp = make_static_distributed_tensor<QGemmDataType>(
Policy::template MakeQRegTileDistribution<Problem>());
constexpr index_t kPackedSize = numeric_traits<QDataType>::PackedSize;
constexpr index_t kUnaryOpSize = 8;
static_assert(std::is_same_v<QDataType, ck_tile::pk_int4_t>);
static_assert(kPackedSize == 2);
static_assert(decltype(q_tile_tmp)::get_thread_buffer_size() ==
decltype(q)::get_thread_buffer_size() * kPackedSize);
static_assert(decltype(q_tile_tmp)::get_thread_buffer_size() % kUnaryOpSize == 0);
using RawQType = typename QDataType::type;
using SrcVectorType = ext_vector_t<RawQType, kUnaryOpSize / kPackedSize>;
using DstVectorType = ext_vector_t<QGemmDataType, kUnaryOpSize>;
constexpr index_t kVecSize =
decltype(q_tile_tmp)::get_thread_buffer_size() / kUnaryOpSize;
static_assert(decltype(q)::get_thread_buffer_size() ==
kVecSize * (kUnaryOpSize / kPackedSize));
const element_wise::PassThroughPack8 pass_through_pack8{};
static_for<0, kVecSize, 1>{}([&](auto i) {
pass_through_pack8(
q_tile_tmp.get_thread_buffer().template get_as<DstVectorType>()(i),
q.get_thread_buffer().template get_as<SrcVectorType>()[i]);
});
return q_tile_tmp;
}
}();
// prefetch K tile
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
// Use compile-time conditional for group barrier sequence
// (No runtime lambda selection)
auto schedule_gemm0 = [] {
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
constexpr auto WarpGemmConfig =
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
constexpr index_t Gemm0MWarp = WarpGemmConfig.template at<1>();
constexpr index_t Gemm0NWarp = WarpGemmConfig.template at<2>();
constexpr index_t WarpGemm0M = WarpGemm0::WarpGemmAttribute::Impl::kM;
constexpr index_t WarpGemm0N = WarpGemm0::WarpGemmAttribute::Impl::kN;
constexpr index_t WarpGemm0K = WarpGemm0::WarpGemmAttribute::Impl::kK;
constexpr index_t NumMfmaInsts = (kM0 / WarpGemm0M) * (kN0 / WarpGemm0N) *
(kK0 / WarpGemm0K) / (Gemm0MWarp * Gemm0NWarp);
if constexpr(get_warp_size() == 64 && kQKHeaddim == 256)
{
static_assert(NumMfmaInsts % 8 == 0);
static_for<0, NumMfmaInsts / 8, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(DS_READ, 2, 0); // DS read
__builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(MFMA, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(DS_READ, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(MFMA, 4, 0); // MFMA
});
}
};
static_assert(2 <= k0_loops);
static_assert(1 <= k1_loops);
constexpr index_t kGemm0MPerWarp = BlockSageAttnShape::Gemm0WarpTile::at(number<0>{});
static_assert(get_warp_size() % kGemm0MPerWarp == 0);
constexpr index_t kWarpSz = get_warp_size();
// sub_warp_idx is 0 or 1, indicating which half of the warp (used for PERTHREAD K-scale
// indexing)
index_t sub_warp_idx = (threadIdx.x % kWarpSz) / kGemm0MPerWarp;
// main loop
do
{
float k_descale = 1.0f;
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
{
const index_t kv_idx =
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
k_descale = k_descale_ptr[kv_idx];
}
constexpr index_t kNumKScalesPW =
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP
? kN0 / Problem::kBlockScaleSizeK
: 1;
constexpr index_t kNumKScalesPT =
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD
? kN0 / Problem::kBlockScaleSizeK / 2
: 1;
float k_scales_perwarp[kNumKScalesPW > 0 ? kNumKScalesPW : 1] = {};
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
{
const index_t kv_idx =
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
#pragma unroll
for(index_t i = 0; i < kNumKScalesPW; i++)
k_scales_perwarp[i] = k_descale_ptr[kv_idx + i];
}
float k_scales_reg[kNumKScalesPT > 0 ? kNumKScalesPT : 1] = {};
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
const index_t k_global_start = seqlen_k_start + i_total_loops * kN0;
const index_t k_scale_start_idx = k_global_start / Problem::kBlockScaleSizeK;
#pragma unroll
for(index_t i = 0; i < kNumKScalesPT; i++)
k_scales_reg[i] = k_descale_ptr[k_scale_start_idx + 2 * i + sub_warp_idx];
}
// STAGE 1, QK gemm
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
auto s_acc_gemm = SaccBlockTileType{};
const auto store_k_block_tile_to_lds = [&](const auto& k_block_tile_) {
if constexpr(std::is_same_v<KDataType, KLdsDataType>)
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile_));
else
{
auto k_block_tile_tmp = make_static_distributed_tensor<KLdsDataType>(
k_dram_window.get_tile_distribution());
using KBlockTileType = remove_cvref_t<decltype(k_block_tile_)>;
constexpr index_t kPackedSize = numeric_traits<KDataType>::PackedSize;
constexpr index_t kUnaryOpSize = 8;
static_assert(std::is_same_v<KDataType, ck_tile::pk_int4_t>);
static_assert(kPackedSize == 2);
static_assert(decltype(k_block_tile_tmp)::get_thread_buffer_size() ==
KBlockTileType::get_thread_buffer_size() * kPackedSize);
static_assert(
decltype(k_block_tile_tmp)::get_thread_buffer_size() % kUnaryOpSize == 0);
using RawKType = typename KDataType::type;
using SrcVectorType = ext_vector_t<RawKType, kUnaryOpSize / kPackedSize>;
using DstVectorType = ext_vector_t<KLdsDataType, kUnaryOpSize>;
constexpr index_t kVecSize =
decltype(k_block_tile_tmp)::get_thread_buffer_size() / kUnaryOpSize;
static_assert(KBlockTileType::get_thread_buffer_size() ==
kVecSize * (kUnaryOpSize / kPackedSize));
const element_wise::PassThroughPack8 pass_through_pack8{};
static_for<0, kVecSize, 1>{}([&](auto i) {
pass_through_pack8(
k_block_tile_tmp.get_thread_buffer().template get_as<DstVectorType>()(
i),
k_block_tile_.get_thread_buffer().template get_as<SrcVectorType>()[i]);
});
store_tile(k_lds_window, k_block_tile_tmp);
}
};
auto k_block_tile = load_tile(k_dram_window);
{
move_tile_window(k_dram_window, {0, kK0});
clear_tile(s_acc_gemm); // initialize C
store_k_block_tile_to_lds(k_block_tile);
k_block_tile = load_tile(k_dram_window);
}
if constexpr(k0_loops > 2)
{
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
block_sync_lds();
gemm_0(s_acc_gemm,
get_slice_tile(q_tile,
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window);
schedule_gemm0();
block_sync_lds();
move_tile_window(k_dram_window, {0, kK0});
store_k_block_tile_to_lds(k_block_tile); // LDS write i + 1
k_block_tile = load_tile(k_dram_window); // global read i + 2
});
}
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
block_sync_lds();
gemm_0(s_acc_gemm,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 2) * kK0>{},
sequence<kM0, (k0_loops - 1) * kK0>{}),
k_lds_window);
schedule_gemm0();
block_sync_lds();
store_k_block_tile_to_lds(k_block_tile);
block_sync_lds();
gemm_0(s_acc_gemm,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
k_lds_window);
schedule_gemm0();
}
// Convert GEMM output to SaccDataType for softmax (if needed)
auto s_acc = [&]() {
using GemmDataType = typename decltype(s_acc_gemm)::DataType;
if constexpr(std::is_same_v<GemmDataType, SaccDataType>)
{
return s_acc_gemm; // No conversion needed (e.g., float -> float)
}
else
{
return cast_tile<SaccDataType>(s_acc_gemm); // Convert (e.g., int32 -> float)
}
}();
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
// PERTHREAD: kBlockScaleSizeK=16
// The s_acc tile distribution is determined by
// WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution, which guarantees
// each thread processes exactly 16 consecutive elements in the K dimension. This
// distribution is inherent to the MFMA 32x32x16 instruction with kKIter=2 and
// TransposedC layout. Therefore, col_offset >> 4 correctly maps thread-local
// elements to K scale indices.
static_assert(Problem::kBlockScaleSizeK == 16,
"PERTHREAD: kBlockScaleSizeK must be 16");
// Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB +
// TransposedC This ensures the distribution has 16 consecutive K elements per
// thread
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
constexpr auto WarpGemmCfg =
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
using ExpectedWarpGemmI8 =
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
using ExpectedWarpGemmFp8 =
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
static_assert(
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
"PERTHREAD requires "
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
"16 consecutive K elements");
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
float combined_scales_reg[kNumKScalesPT] = {};
#pragma unroll
for(index_t i = 0; i < kNumKScalesPT; i++)
combined_scales_reg[i] = q_descale_value * k_scales_reg[i];
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
index_t col_offset = 0;
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// col_offset counts columns in distributed view
// Divide by 16 (>>4) to map to K scale groups (kBlockScaleSizeK=16)
const index_t scale_idx = col_offset >> 4;
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
col_offset++;
});
});
}
else if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
{
// PERWARP: kBlockScaleSizeK=64, i.e., 64 global K elements share one scale
// Distribution: thread_i and thread_(i+32) interleave to cover K dimension
// In each thread's view, every 32 idx1 steps correspond to 64 global K elements
// Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB +
// TransposedC This ensures each thread has 16 consecutive elements, and warp-level
// grouping is correct
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
constexpr auto WarpGemmCfg =
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
using ExpectedWarpGemmI8 =
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
using ExpectedWarpGemmFp8 =
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
static_assert(
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
"PERWARP requires "
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
"correct K element grouping");
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
float combined_scales_reg[kNumKScalesPW] = {};
#pragma unroll
for(index_t i = 0; i < kNumKScalesPW; i++)
combined_scales_reg[i] = q_descale_value * k_scales_perwarp[i];
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
index_t col_offset = 0;
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// col_offset counts columns in distributed view
// When N0=64: each thread has 32 elements; when N0=128: each thread has 64
// elements Divide by 32 (>>5) to map to K scale groups
// (kBlockScaleSizeK=64)
const index_t scale_idx = col_offset >> 5;
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
col_offset++;
});
});
}
else
{
// dequant: combine q_descale (in s_acc_element_func) with k_descale
auto s_acc_element_func_ = [&]() {
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
{
return s_acc_element_func * k_descale;
}
else
return s_acc_element_func;
}();
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
}
// STAGE 2, scale_s, mask, softmax
if constexpr(kPadSeqLenK || AttnMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !mask_func(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
};
apply_mask([&](auto&&... args) {
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
if constexpr(AttnMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
// For BLOCKSCALE: precompute (m - shift) once per row
// exp2(s - m + shift) = exp2(s - (m - shift)); pertensor path uses scale_s on s,m
auto validated_m = get_validated_m(m[i_idx]);
auto row_max = scale_s * validated_m;
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP ||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
#if CK_TILE_USE_OCP_FP8
validated_m -= OCP_FP8_SHIFT; // OCP FP8 softmax shift
row_max -= OCP_FP8_SHIFT; // for else branch
#else
validated_m -= FNUZ_FP8_SHIFT;
row_max -= FNUZ_FP8_SHIFT;
#endif
}
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto m_new = get_validated_m(m[i_idx]);
auto row_max = scale_s * m_new;
const auto tmp = exp2(scale_s * m_old[i_idx] - row_max);
// Update l and rescale o_acc
l(i_idx) = tmp * l(i_idx) + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_prefetch);
store_tile(
v_lds_window,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
}
move_tile_window(v_dram_window, {0, kK1});
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
// STAGE 3, KV gemm
// For BLOCKSCALE, PERWARP, and PERTHREAD modes, accumulate directly to o_acc
// Apply per-channel v_descale after the loop (before normalization)
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto v = load_tile(v_dram_window); // load next v
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v);
store_tile(v_lds_window,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v
}
move_tile_window(v_dram_window, {0, kK1});
});
}
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
// tail
{
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
v_lds_window);
block_sync_lds();
}
} while(++i_total_loops < num_total_loop);
// Apply per-channel v_descale for BLOCKSCALE, PERWARP, and PERTHREAD modes (after loop,
// before normalization)
if constexpr(Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP ||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
// Ensure all V LDS reads from the last gemm_1 complete before reusing K/V LDS space
block_sync_lds();
// V is col-major, each column (channel) has its own scale
// o_acc shape: [M0, N1] where N1 is hdim_v
// v_descale_ptr points to per-channel scales [hdim_v]
// Load v_descale to LDS for better memory access pattern
// Reuse K/V LDS space (they're no longer needed)
auto v_descale_lds = reinterpret_cast<float*>(smem_ptr);
// Cooperatively load v_descale to LDS
const index_t num_threads = kBlockSize;
for(index_t i = threadIdx.x; i < kN1; i += num_threads)
{
v_descale_lds[i] = v_descale_ptr[i];
}
block_sync_lds();
constexpr auto o_tmp_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_tmp_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(o_tmp_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// Get the global tile index for the N1 (channel) dimension
const auto tile_idx = get_x_indices_from_distributed_indices(
o_acc.get_tile_distribution(), i_j_idx);
const index_t channel_idx = tile_idx.at(number<1>{});
const float v_scale = v_descale_lds[channel_idx];
o_acc(i_j_idx) *= v_scale;
});
});
}
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
// When masking, the denominator can be zero; guard the normalization
// so we do not divide by zero after a fully masked row.
if constexpr(AttnMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
else
return 1 / l[i_idx];
}();
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
          typename KDramBlockWindowTmp,
          typename VDramBlockWindowTmp,
          typename PositionEncoding,
          typename AttentionVariantParams,
          typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
           const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
           const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
           AttnMask mask,
           PositionEncoding position_encoding,
           float scale_s,
           const AttentionVariant& variant,
           const AttentionVariantParams& variant_params,
           const BlockIndices& block_indices,
           void* smem_ptr,
           [[maybe_unused]] const float* q_descale_ptr = nullptr,
           const float* k_descale_ptr = nullptr,
           const float* v_descale_ptr = nullptr,
           [[maybe_unused]] float q_descale_value = 1.0f) const
{
    // Convenience overload: forwards to the fully-parameterized operator() above,
    // supplying identity element-wise functors for the Q/K/V/S/P/O hooks so no
    // extra per-element transform is applied. All other arguments pass through
    // unchanged, including the optional per-tensor descale pointers/value.
    const identity pass_through{};
    return (*this)(q_dram_block_window_tmp,
                   pass_through, // q_element_func
                   k_dram_block_window_tmp,
                   pass_through, // k_element_func
                   v_dram_block_window_tmp,
                   pass_through, // v_element_func
                   pass_through, // s_acc_element_func
                   pass_through, // p_compute_element_func
                   pass_through, // o_acc_element_func
                   mask,
                   position_encoding,
                   scale_s,
                   variant,
                   variant_params,
                   block_indices,
                   smem_ptr,
                   q_descale_ptr,
                   k_descale_ptr,
                   v_descale_ptr,
                   q_descale_value);
}
};
} // namespace ck_tile

// ---------------------------------------------------------------------------------------------
// (file boundary — PR-view artifact replaced) The remainder is a separate 873-line header that
// defines the async-copy variant of the SageAttention forward pipeline
// (BlockSageAttentionPipelineQRKSVSAsync).
// ---------------------------------------------------------------------------------------------
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_, typename Policy_ = BlockSageAttentionPipelineQRKSVSAsyncDefaultPolicy>
struct BlockSageAttentionPipelineQRKSVSAsync
{
// Strip cv/ref qualifiers from the template parameters once, up front.
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
// Element types of the Q/K/V operands as described by the Problem.
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
// Accumulator type of the QK gemm, the compute type used for the softmax stage
// (row-max / exp2 / row-sum), and the type the softmax output P is cast to
// before the PV gemm.
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
// fp16/bf16 example configs use P=V=fp16/bf16 (qscale=no). Quantized Sage paths use fp8 P/V;
// FP8 softmax shift, v_descale, and PV-gemm LDS layout assume fp8_t for those cases.
static_assert(std::is_same_v<PDataType, VDataType>,
"SageAttention pipeline requires PDataType == VDataType for the PV gemm");
// Unless Q is a plain fp16/bf16 type (the non-quantized path), P and V must be fp8.
static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
std::is_same_v<PDataType, fp8_t>,
"SageAttention pipeline requires PDataType = fp8_t");
static_assert(std::is_same_v<QDataType, half_t> || std::is_same_v<QDataType, bf16_t> ||
std::is_same_v<VDataType, fp8_t>,
"SageAttention pipeline requires VDataType = fp8_t");
// Accumulator / output types of the PV gemm, plus the attention variant, mask,
// and tile-shape descriptors supplied by the Problem.
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using AttnMask = remove_cvref_t<typename Problem::AttnMask>;
using BlockSageAttnShape = remove_cvref_t<typename Problem::BlockSageAttnShape>;
using VLayout = remove_cvref_t<typename BlockSageAttnShape::VLayout>;
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
// The policy must agree with this pipeline about the Q load strategy.
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize;
// Block-tile extents: M0 x N0 for the QK gemm (K0 = its inner-dim step),
// N1 x K1 for the PV gemm, and the Q/K head dimension.
static constexpr index_t kM0 = BlockSageAttnShape::kM0;
static constexpr index_t kN0 = BlockSageAttnShape::kN0;
static constexpr index_t kK0 = BlockSageAttnShape::kK0;
static constexpr index_t kN1 = BlockSageAttnShape::kN1;
static constexpr index_t kK1 = BlockSageAttnShape::kK1;
static constexpr index_t kQKHeaddim = BlockSageAttnShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockSageAttnShape::kSubQKHeaddim;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
Problem::kPadHeadDimV == true);
static constexpr bool kPadSeqLenQ = true;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
// Granularity of the Q/K quantization scales (per-tensor / block / warp / thread).
static constexpr auto QScaleEnum = Problem::QScaleEnum;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return Policy::template GetAlignmentV<Problem>();
else
// col-major V: when seq_k may be padded, fall back to element-wise (alignment 1) loads
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
// FP8 softmax shift constants to map softmax output into representable FP8 range
// OCP E4M3 FP8: max exponent = 8, max value ~240 (2^8 * 1.875)
// Use shift=8.0 so exp2(s - m - 8) maps softmax to [0, 2^8] range
// FNUZ E4M3 FP8: max exponent = 7, max value ~120 (2^7 * 1.875)
// Use shift=7.0 so exp2(s - m - 7) maps softmax to [0, 2^7] range
static constexpr float OCP_FP8_SHIFT = 8.0f;
static constexpr float FNUZ_FP8_SHIFT = 7.0f;
// Occupancy hint: number of workgroups to place per CU. An explicit value in the
// Problem (anything other than -1) takes precedence; otherwise a default is
// chosen from the Q/K head dimension via the tier table below.
static constexpr index_t kBlockPerCu = []() {
    if constexpr(Problem::kBlockPerCu != -1)
    {
        // Caller supplied an explicit occupancy override.
        return Problem::kBlockPerCu;
    }
    else if constexpr(kQKHeaddim <= 32)
    {
        return 2;
    }
    else if constexpr(kQKHeaddim <= 64)
    {
        return 3;
    }
    else if constexpr(kQKHeaddim <= 192)
    {
        // The 65-128 and 129-192 head-dim tiers share the same value.
        return 2;
    }
    else
    {
        // Head dims above 192 (including the 256 cap and anything larger).
        return 1;
    }
}();
// Short human-readable tag identifying this (async-copy) pipeline variant.
static constexpr const char* name = "qr_async";
/// @brief Bytes of shared memory (LDS) this pipeline requires; the computation
///        is delegated entirely to the policy for this Problem configuration.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    constexpr ck_tile::index_t lds_bytes = Policy::template GetSmemSize<Problem>();
    return lds_bytes;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding,
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KElementFunction& /*k_element_func*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
AttnMask mask,
PositionEncoding /*position_encoding*/,
float scale_s,
const AttentionVariant& variant,
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
[[maybe_unused]] const float* q_descale_ptr = nullptr,
const float* k_descale_ptr = nullptr,
const float* v_descale_ptr = nullptr,
[[maybe_unused]] float q_descale_value = 1.0f) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
// K tile in LDS
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
auto k_lds_store = generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumKVLdsBuffers>{});
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
auto k_lds_load =
make_tile_window(k_lds_Load_view,
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
{0, 0});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>());
q_dram_window.init_raw();
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
auto q = decltype(load_tile(q_dram_window)){};
// TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
// however, q would be cleared in the constructor of static distributed tensor
// set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw(q, q_dram_window);
__builtin_amdgcn_sched_barrier(0);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType =
std::conditional_t<std::is_same_v<typename SaccBlockTileType::DataType, SaccDataType>,
SaccBlockTileType,
decltype(cast_tile<SaccDataType>(SaccBlockTileType{}))>;
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
clear_tile(o_acc);
{
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
}
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto tile_range_result = [&mask, &q_origin]() {
auto [start, end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
return ck_tile::make_tuple(start, end);
}();
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<0>{});
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<1>{});
const auto kv_load_start = seqlen_k_start > 0 ? seqlen_k_start : 0;
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if no work to do
if constexpr(AttnMask::IsMasking || kPadSeqLenK)
{
if(num_total_loop <= 0)
{
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it
return o_acc;
}
__builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
}
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{kv_load_start, 0});
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
k_dram_window.init_raw();
constexpr auto k_oob_ck = bool_constant<true>{};
constexpr auto k_pre_np = bool_constant<false>{};
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, kv_load_start},
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(1 <= k0_loops);
static_assert(1 <= k1_loops);
constexpr index_t kGemm0MPerWarp = BlockSageAttnShape::Gemm0WarpTile::at(number<0>{});
static_assert(kGemm0MPerWarp == 32);
constexpr index_t kWarpSz = get_warp_size();
// sub_warp_idx is 0 or 1, indicating which half of the warp (used for PERTHREAD K-scale
// indexing)
index_t sub_warp_idx = (threadIdx.x % kWarpSz) / kGemm0MPerWarp;
// main loop
do
{
float k_descale = 1.0f;
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
{
const index_t kv_idx =
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
k_descale = k_descale_ptr[kv_idx];
}
constexpr index_t kNumKScalesPW =
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP
? kN0 / Problem::kBlockScaleSizeK
: 1;
constexpr index_t kNumKScalesPT =
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD
? kN0 / Problem::kBlockScaleSizeK / 2
: 1;
float k_scales_perwarp[kNumKScalesPW > 0 ? kNumKScalesPW : 1] = {};
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
{
const index_t kv_idx =
(seqlen_k_start + i_total_loops * kN0) / Problem::kBlockScaleSizeK;
#pragma unroll
for(index_t i = 0; i < kNumKScalesPW; i++)
k_scales_perwarp[i] = k_descale_ptr[kv_idx + i];
}
float k_scales_reg[kNumKScalesPT > 0 ? kNumKScalesPT : 1] = {};
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
const index_t k_global_start = seqlen_k_start + i_total_loops * kN0;
const index_t k_scale_start_idx = k_global_start / Problem::kBlockScaleSizeK;
#pragma unroll
for(index_t i = 0; i < kNumKScalesPT; i++)
k_scales_reg[i] = k_descale_ptr[k_scale_start_idx + 2 * i + sub_warp_idx];
}
// STAGE 1, QK gemm
auto s_acc_gemm = SaccBlockTileType{};
clear_tile(s_acc_gemm); // initialize C
if constexpr(k0_loops > 1)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_of_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc_gemm,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
});
}
// TODO: this to fix a bug when loop smaller than 2,
// the following fence/barrier will be scheduled inside 1st loop
if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0);
async_load_fence();
__builtin_amdgcn_s_barrier();
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(
s_acc_gemm,
get_slice_tile(
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
}
__builtin_amdgcn_sched_barrier(1);
// Convert GEMM output to SaccDataType for softmax (if needed)
auto s_acc = [&]() {
using GemmDataType = typename decltype(s_acc_gemm)::DataType;
if constexpr(std::is_same_v<GemmDataType, SaccDataType>)
{
return s_acc_gemm; // No conversion needed (e.g., float -> float)
}
else
{
return cast_tile<SaccDataType>(s_acc_gemm); // Convert (e.g., int32 -> float)
}
}();
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
// PERTHREAD: kBlockScaleSizeK=16
// The s_acc tile distribution is determined by
// WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution, which guarantees
// each thread processes exactly 16 consecutive elements in the K dimension. This
// distribution is inherent to the MFMA 32x32x16 instruction with kKIter=2 and
// TransposedC layout. Therefore, col_offset >> 4 correctly maps thread-local
// elements to K scale indices.
static_assert(Problem::kBlockScaleSizeK == 16,
"PERTHREAD: kBlockScaleSizeK must be 16");
// Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB +
// TransposedC This ensures the distribution has 16 consecutive K elements per
// thread
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
constexpr auto WarpGemmCfg =
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
using ExpectedWarpGemmI8 =
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
using ExpectedWarpGemmFp8 =
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
static_assert(
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
"PERTHREAD requires "
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
"16 consecutive K elements");
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
float combined_scales_reg[kNumKScalesPT] = {};
#pragma unroll
for(index_t i = 0; i < kNumKScalesPT; i++)
combined_scales_reg[i] = q_descale_value * k_scales_reg[i];
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
index_t col_offset = 0;
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// col_offset counts columns in distributed view
// Divide by 16 (>>4) to map to K scale groups (kBlockScaleSizeK=16)
const index_t scale_idx = col_offset >> 4;
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
col_offset++;
});
});
}
else if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP)
{
// PERWARP: kBlockScaleSizeK=64, i.e., 64 global K elements share one scale
// Distribution: thread_i and thread_(i+32) interleave to cover K dimension
// In each thread's view, every 32 idx1 steps correspond to 64 global K elements
// Validate the WarpGemm type matches the expected MFMA instruction with SwizzleB +
// TransposedC This ensures each thread has 16 consecutive elements, and warp-level
// grouping is correct
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
constexpr auto WarpGemmCfg =
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm0Type = remove_cvref_t<decltype(WarpGemmCfg.template at<0>())>;
using ExpectedWarpGemmI8 =
WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<4>;
using ExpectedWarpGemmFp8 =
WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<4>;
static_assert(
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmI8> ||
std::is_same_v<WarpGemm0Type, ExpectedWarpGemmFp8>,
"PERWARP requires "
"WarpGemmMfma[I8I8I32|Fp8Fp8F32]M32N32K32SwizzleBTransposedCDistribution for "
"correct K element grouping");
constexpr auto s_acc_spans = decltype(s_acc)::get_distributed_spans();
float combined_scales_reg[kNumKScalesPW] = {};
#pragma unroll
for(index_t i = 0; i < kNumKScalesPW; i++)
combined_scales_reg[i] = q_descale_value * k_scales_perwarp[i];
sweep_tile_span(s_acc_spans[number<0>{}], [&](auto idx0) {
index_t col_offset = 0;
sweep_tile_span(s_acc_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// col_offset counts columns in distributed view
// When N0=64: each thread has 32 elements; when N0=128: each thread has 64
// elements Divide by 32 (>>5) to map to K scale groups
// (kBlockScaleSizeK=64)
const index_t scale_idx = col_offset >> 5;
s_acc(i_j_idx) *= combined_scales_reg[scale_idx];
col_offset++;
});
});
}
else
{
// dequant: combine q_descale (in s_acc_element_func) with k_descale
auto s_acc_element_func_ = [&]() {
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE)
{
return s_acc_element_func * k_descale;
}
else
return s_acc_element_func;
}();
s_acc = tile_elementwise_in(s_acc_element_func_, s_acc);
}
// STAGE 2, scale_s, mask, softmax
// logits_soft_cap is always disabled
if constexpr(kPadSeqLenK || AttnMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
auto apply_mask = [&](auto&& mask_func) {
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return !mask_func(variant_params,
block_indices.batch_idx,
row,
col,
block_indices.qo_head_idx,
block_indices.kv_head_idx);
});
};
apply_mask([&](auto&&... args) {
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0x7F);
// Ensure gemm_0's LDS reads (K tile) from all threads are completed before V store
// Only needed when K tail and V use the same LDS buffer
if constexpr(LdsSeq.at(number<k0_loops - 1>{}) == LdsSeq.at(number<k0_loops>{}))
{
__builtin_amdgcn_s_barrier();
}
// store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(
v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
}
if constexpr(k1_loops > 1)
{
move_tile_window(
v_dram_window,
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
}
__builtin_amdgcn_sched_barrier(0);
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
if constexpr(AttnMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
// For BLOCKSCALE: precompute (m - shift) once per row
// exp2(s - m + shift) = exp2(s - (m - shift))
// else: exp2(scale_s*s - scale_s*m + shift) = exp2(scale_s*s - (scale_s*m - shift))
auto validated_m = get_validated_m(m[i_idx]);
auto row_max = scale_s * validated_m;
if constexpr(QScaleEnum == BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERWARP ||
QScaleEnum == BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
#if CK_TILE_USE_OCP_FP8
validated_m -= OCP_FP8_SHIFT; // OCP FP8 softmax shift
row_max -= OCP_FP8_SHIFT; // for else branch
#else
validated_m -= FNUZ_FP8_SHIFT;
row_max -= FNUZ_FP8_SHIFT;
#endif
}
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// logits_soft_cap is always disabled
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto m_new = get_validated_m(m[i_idx]);
auto row_max = scale_s * m_new;
const auto tmp = exp2(scale_s * m_old[i_idx] - row_max);
// Update l and rescale o_acc
l(i_idx) = tmp * l(i_idx) + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
const auto p = [&]() {
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
// For fp32 to fp16,
// impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue,
// since it uses __builtin_amdgcn_cvt_pkrtz, which is round to zero.
return cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
#else
if constexpr(std::is_same_v<PDataType, fp16_t>)
return impl::cast_tile_pkrtz_fp16_fp32<PDataType>(
tile_elementwise_in(p_compute_element_func, p_compute));
else
return cast_tile<PDataType>(
tile_elementwise_in(p_compute_element_func, p_compute));
#endif
}();
// STAGE 3, KV gemm
// For BLOCKSCALE, PERWARP, and PERTHREAD modes, accumulate directly to o_acc
// Apply per-channel v_descale after the loop (before normalization)
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
}
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
}
if constexpr(i_k1 < k1_loops - 1)
move_tile_window(v_dram_window, {0, kK1});
});
}
i_total_loops++;
if(i_total_loops < num_total_loop)
{
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
}
// tail
{
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
}
} while(i_total_loops < num_total_loop);
// Apply per-channel v_descale for BLOCKSCALE, PERWARP, and PERTHREAD modes (after loop,
// before normalization)
if constexpr(Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::BLOCKSCALE ||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERWARP ||
Problem::QScaleEnum == ck_tile::BlockSageAttentionQuantScaleEnum::PERTHREAD)
{
// Ensure all V LDS reads from the last gemm_1 complete before reusing K/V LDS space
block_sync_lds();
// V is col-major, each column (channel) has its own scale
// o_acc shape: [M0, N1] where N1 is hdim_v
// v_descale_ptr points to per-channel scales [hdim_v]
// Load v_descale to LDS for better memory access pattern
// Reuse K/V LDS space (they're no longer needed)
auto v_descale_lds = reinterpret_cast<float*>(smem_ptr);
// Cooperatively load v_descale to LDS
const index_t num_threads = kBlockSize;
for(index_t i = threadIdx.x; i < kN1; i += num_threads)
{
v_descale_lds[i] = v_descale_ptr[i];
}
block_sync_lds();
constexpr auto o_tmp_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_tmp_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(o_tmp_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// Get the global tile index for the N1 (channel) dimension
const auto tile_idx = get_x_indices_from_distributed_indices(
o_acc.get_tile_distribution(), i_j_idx);
const index_t channel_idx = tile_idx.at(number<1>{});
const float v_scale = v_descale_lds[channel_idx];
o_acc(i_j_idx) *= v_scale;
});
});
}
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
if constexpr(AttnMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
else
return 1 / l[i_idx];
}();
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
          typename KDramBlockWindowTmp,
          typename VDramBlockWindowTmp,
          typename PositionEncoding,
          typename AttentionVariantParams,
          typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
           const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
           const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
           AttnMask mask,
           PositionEncoding position_encoding,
           float scale_s,
           const AttentionVariant& variant,
           const AttentionVariantParams& variant_params,
           const BlockIndices& block_indices,
           void* smem_ptr,
           const float* q_descale_ptr = nullptr,
           const float* k_descale_ptr = nullptr,
           const float* v_descale_ptr = nullptr,
           [[maybe_unused]] float q_descale_value = 1.0f) const
{
    // Convenience overload: every element-wise functor slot defaults to a
    // pass-through, then delegate to the fully-parameterized overload.
    const identity pass_through{};
    return operator()(q_dram_block_window_tmp,
                      pass_through,
                      k_dram_block_window_tmp,
                      pass_through,
                      v_dram_block_window_tmp,
                      pass_through,
                      pass_through,
                      pass_through,
                      pass_through,
                      mask,
                      position_encoding,
                      scale_s,
                      variant,
                      variant_params,
                      block_indices,
                      smem_ptr,
                      q_descale_ptr,
                      k_descale_ptr,
                      v_descale_ptr,
                      q_descale_value);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,18 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_custom_policy.hpp"
namespace ck_tile {
// This pipeline is qkv all located in LDS
// Default policy for the async-copy QR/KS/VS SageAttention pipeline:
// Q stays in registers (QLoadOnce = true, so no Q LDS is reserved) while
// K and V are each staged through 3 prefetch buffers in the shared K/V
// LDS ring (see NumKVLdsBuffers in the custom policy).
using BlockSageAttentionPipelineQRKSVSAsyncDefaultPolicy =
BlockSageAttnPipelineQRKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ true,
/* NumPrefetchK = */ 3,
/* NumPrefetchV = */ 3>;
} // namespace ck_tile

View File

@@ -0,0 +1,857 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
namespace ck_tile {
// Number of logical elements packed into one storage object of T, as
// reported by the type's numeric_traits (e.g. > 1 for sub-byte types).
template <typename T>
CK_TILE_HOST_DEVICE static constexpr index_t GetPackedSize()
{
    using PlainT = remove_cvref_t<T>;
    return numeric_traits<PlainT>::PackedSize;
}
// Convert a raw byte count into the number of logical elements of T it
// can hold, accounting for packed/sub-byte types via GetPackedSize.
template <typename T>
CK_TILE_HOST_DEVICE static constexpr index_t GetLogicalVectorSize(index_t bytes)
{
    const index_t storage_elems = bytes / sizeof(remove_cvref_t<T>);
    return storage_elems * GetPackedSize<T>();
}
// Element type actually fed to the QK gemm for Q: packed sub-byte types
// (as identified by is_packed_type_v, e.g. int4) are computed as fp8;
// everything else is used as-is.
template <typename Problem>
using SageAttnQKGemmQDataType =
std::conditional_t<is_packed_type_v<remove_cvref_t<typename Problem::QDataType>>,
fp8_t,
remove_cvref_t<typename Problem::QDataType>>;
// Same mapping for K: packed sub-byte K types are widened to fp8 for the
// QK gemm.
template <typename Problem>
using SageAttnQKGemmKDataType =
std::conditional_t<is_packed_type_v<remove_cvref_t<typename Problem::KDataType>>,
fp8_t,
remove_cvref_t<typename Problem::KDataType>>;
// Primary template of the Q-in-registers ("QR") pipeline policy; only the
// QLoadOnce = true case is defined below.
template <bool QLoadOnce_>
struct BlockSageAttnPipelineQRCustomPolicy;
template <>
struct BlockSageAttnPipelineQRCustomPolicy</* QLoadOnce = */ true>
{
static constexpr bool QLoadOnce = true;
// Q is kept in registers, so it consumes no LDS.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
{
return 0;
}
// TODO: GetAlignment*() currently didn't consider if need padding or not
// so in pipeline still need check padding requirement
// Global-load vector width for Q (in logical QDataType elements): a
// 16-byte load, capped by the per-lane K extent of the QK warp gemm.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
constexpr index_t MaxVectorSize = GetLogicalVectorSize<typename Problem::QDataType>(16);
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
}
// Register-tile distribution for Q: the A-operand layout of the QK block
// gemm over a kM0 x kSubQKHeaddim tile.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
return BlockGemm::template MakeABlockTileDistribution<
Problem::BlockSageAttnShape::kM0,
Problem::BlockSageAttnShape::kSubQKHeaddim>();
}
// Construct the block gemm computing S = Q * K^T, choosing a warp-gemm
// specialization per input dtype combination (fp8 / int8 / generic).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using QKGemmQDataType = SageAttnQKGemmQDataType<Problem>;
using QKGemmKDataType = SageAttnQKGemmKDataType<Problem>;
// int8 MFMA accumulates to int32, but SaccDataType is float for softmax
using GemmAccDataType =
std::conditional_t<(std::is_same_v<QKGemmQDataType, int8_t> ||
std::is_same_v<QKGemmQDataType, signed char>) &&
(std::is_same_v<QKGemmKDataType, int8_t> ||
std::is_same_v<QKGemmKDataType, signed char>),
int32_t,
typename Problem::SaccDataType>;
using GemmProblem =
BlockGemmProblem<QKGemmQDataType,
QKGemmKDataType,
GemmAccDataType,
Problem::kNumGemm0Warps * get_warp_size(),
TileGemmShape<sequence<Problem::BlockSageAttnShape::kM0,
Problem::BlockSageAttnShape::kN0,
Problem::BlockSageAttnShape::kK0>,
typename Problem::BlockSageAttnShape::Gemm0BlockWarps,
typename Problem::BlockSageAttnShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
// fp8 x fp8 with fp32 accumulation: dedicated 32x32x32 MFMA with
// swizzled-B, transposed-C distribution (requires a 32x32x32 warp tile)
if constexpr(get_warp_size() == 64 && std::is_same_v<QKGemmQDataType, fp8_t> &&
std::is_same_v<QKGemmKDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}) == 32);
static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<1>{}) == 32);
static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<2>{}) == 32);
// TODO: hard coded here. Otherwise, it produces incorrect results
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
swizzle_factor>{};
}
// int8 x int8 -> int32: same tile shape and swizzle as the fp8 path
else if constexpr(get_warp_size() == 64 &&
(std::is_same_v<QKGemmQDataType, int8_t> ||
std::is_same_v<QKGemmQDataType, signed char>) &&
(std::is_same_v<QKGemmKDataType, int8_t> ||
std::is_same_v<QKGemmKDataType, signed char>))
{
static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}) == 32);
static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<1>{}) == 32);
static_assert(Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<2>{}) == 32);
// Use special int8 MFMA with K iteration (similar to FP8)
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaI8I8I32M32N32K32SwizzleBTransposedCDistribution<
swizzle_factor>{};
}
else
{
// generic fallback: dispatch purely on the warp-tile shape
constexpr bool SwizzleA =
Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}) == 32;
return WarpGemmDispatcher<
QKGemmQDataType,
QKGemmKDataType,
GemmAccDataType,
Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockSageAttnShape::Gemm0WarpTile::at(number<2>{}),
true, // TransposeC
SwizzleA>{};
}
}();
using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
QKGemmQDataType,
QKGemmKDataType,
GemmAccDataType,
typename Problem::BlockSageAttnShape::Gemm0BlockWarps,
decltype(warp_gemm)>;
// a single-warp gemm0 uses the dedicated one-warp implementation
if constexpr(1 < Problem::kNumGemm0Warps)
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
else
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
};
// This pipeline is qkv all located in LDS
template <bool QLoadOnce_, bool AsyncCopy_, index_t NumPrefetchK_, index_t NumPrefetchV_>
struct BlockSageAttnPipelineQRKSVSCustomPolicy : BlockSageAttnPipelineQRCustomPolicy<QLoadOnce_>
{
static constexpr bool AsyncCopy = AsyncCopy_;
static constexpr index_t NumPrefetchK = NumPrefetchK_;
static constexpr index_t NumPrefetchV = NumPrefetchV_;
static constexpr index_t NumKVLdsBuffers = max(NumPrefetchK, NumPrefetchV);
using QXPolicy = BlockSageAttnPipelineQRCustomPolicy<QLoadOnce_>;
// Compile-time schedule assigning one of num_lds_buffers_ LDS buffers to
// each of the k_loops_ K-load iterations followed by v_loops_ V-load
// iterations (sequence index i ranges over [0, k_loops_ + v_loops_)).
template <index_t k_prefetches_, index_t v_prefetches_, index_t k_loops_, index_t v_loops_>
struct LdsBufferSequence
{
static constexpr index_t num_lds_buffers_ = max(k_prefetches_, v_prefetches_);
// (v_loops_ - 1) rounded down to a multiple of num_lds_buffers_
static constexpr index_t ceil_ = ((v_loops_ - 1) / num_lds_buffers_) * num_lds_buffers_;
// for qr_ks_vs_async, the Lds buffer assigned to last gemm_1 iteration of V should not
// overlap with the Lds buffers used by first two gemm_0 iterations of K
static constexpr auto Make()
{
// ensure v_loop_-1 is assigned to num_lds_buffers-1
return transform_sequences(
[&](auto i) {
// K iterations round-robin through the buffers
if(i < k_loops_)
return i % num_lds_buffers_;
else
// V iterations are phase-shifted so the final V iteration
// always lands on buffer num_lds_buffers_ - 1
return ((num_lds_buffers_ - 1) + (i - k_loops_ + ceil_ - (v_loops_ - 1))) %
num_lds_buffers_;
},
typename arithmetic_sequence_gen<0, k_loops_ + v_loops_, 1>::type{});
};
using type = remove_cvref_t<decltype(Make())>;
};
// Hand-tuned buffer schedules for specific small loop counts, overriding
// the generic formula above (3 K-prefetch / 3 V-prefetch buffers only).
// clang-format off
template<> struct
LdsBufferSequence<3, 3, 4, 4> { using type = sequence<1, 2, 0, 1, 0, 1, 2, 0>; };
template<> struct
LdsBufferSequence<3, 3, 4, 2> { using type = sequence<1, 2, 0, 1, 2, 0>; };
template<> struct
LdsBufferSequence<3, 3, 2, 4> { using type = sequence<1, 2, 0, 1, 2, 0>; };
template<> struct
LdsBufferSequence<3, 3, 3, 3> { using type = sequence<1, 2, 0, 1, 2, 0>; };
template<> struct
LdsBufferSequence<3, 3, 3, 4> { using type = sequence<1, 2, 0, 0, 1, 2, 0>; };
template<> struct
LdsBufferSequence<3, 3, 2, 2> { using type = sequence<1, 2, 1, 0>;};
// clang-format on
// Select the LDS buffer index sequence for this problem's loop structure:
// k0_loops = kQKHeaddim / kK0 gemm_0 (K) iterations followed by
// k1_loops = kN0 / kK1 gemm_1 (V) iterations.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetLdsBufferSequence()
{
    using Shape = remove_cvref_t<typename Problem::BlockSageAttnShape>;
    constexpr index_t gemm0_iters = Shape::kQKHeaddim / Shape::kK0;
    constexpr index_t gemm1_iters = Shape::kN0 / Shape::kK1;
    return
        typename LdsBufferSequence<NumPrefetchK, NumPrefetchV, gemm0_iters, gemm1_iters>::type{};
}
// Pack size (logical elements) used when laying K out in LDS: one 16-byte
// vector of the LDS K element type (packed K dtypes are staged as fp8).
// TODO: this is for 3d layout
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
{
    return GetLogicalVectorSize<SageAttnQKGemmKDataType<Problem>>(16);
}
// Global-load vector width for K (in logical KDataType elements).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
if constexpr(AsyncCopy)
{
// async copy goes straight to LDS; load width is limited by the
// widest per-lane async load the target supports
#if defined(__gfx950__)
constexpr index_t MaxLoadSizeInBytes = 4 * 4; // dwordx4
#else
constexpr index_t MaxLoadSizeInBytes = 4; // dword
#endif
return GetLogicalVectorSize<KDataType>(MaxLoadSizeInBytes);
}
else
{
// sync copy: up to a 16-byte vector, capped by the number of
// elements of the N0 x K0 tile each thread owns
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
constexpr index_t MaxVectorSize = GetLogicalVectorSize<KDataType>(16);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
return min(MaxVectorSize, ElemPerThread);
}
}
// Pack size (elements of VDataType) used when laying V out in LDS: the
// per-thread share of the N1 x K1 tile, capped at one 16-byte vector.
// TODO: this is for 3d layout
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{
    using VDataType = remove_cvref_t<typename Problem::VDataType>;
    // elements of the N1 x K1 tile owned by each thread
    constexpr index_t pixels_per_thread =
        (Problem::BlockSageAttnShape::kN1 * Problem::BlockSageAttnShape::kK1) /
        Problem::kBlockSize;
    constexpr index_t vec_16_bytes = static_cast<index_t>(16 / sizeof(VDataType));
    return min(pixels_per_thread, vec_16_bytes);
}
// Global-load vector width for V (in VDataType elements).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
using VLayout = remove_cvref_t<typename Problem::BlockSageAttnShape::VLayout>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
// elements of the N1 x K1 tile owned by each thread
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
// at most one 16-byte vector per load
constexpr index_t kMaxVecLoad =
min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
// row-major V: keep at least a dword-sized load; if the full-width
// vector would leave fewer than kMinVecLoad issues per thread, fall
// back to a narrower vector
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (total_pixels / kMinVecLoad);
return kVecLoad;
}
else
{
return kMaxVecLoad;
}
}
// Vector width (in ODataType elements) for writing O out: one 16-byte
// store, capped by the per-lane C-tile extent (kCM1PerLane) of the KV
// warp gemm.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
{
    using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
    constexpr auto warp_cfg = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
    using WarpGemm = remove_cvref_t<decltype(warp_cfg.template at<0>())>;
    constexpr index_t max_vec_size = 16 / sizeof(typename Problem::ODataType);
    return min(max_vec_size, WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane);
}
// Element count of ONE shared K/V LDS buffer: the max of the K-tile and
// V-tile footprints, so a single buffer can hold either tensor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
{
// this function assume K/V can share smem
constexpr index_t SingleKSize = [&]() {
if constexpr(!AsyncCopy)
{
// sync path: size comes straight from the padded K LDS descriptor
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size();
}
else
{
// async path: replicate the padded layout math of
// MakeKLdsStoreBlockDescriptor (kPad elements of padding per warp
// chunk)
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack;
static_assert(WarpSize * KVector >= kKPerBlock &&
WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector;
constexpr index_t LaneGroups = WarpSize / LanesPerK;
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
return NumIssues * NumWarps * (WarpSize * KVector + kPad);
}
}();
constexpr index_t SingleVSize = [&]() {
// mirrors MakeVLdsBlockDescriptor: rows of PixelsPerRow pixels, each
// padded by one kKPack
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t Banks = get_n_lds_banks();
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
}();
return max(SingleKSize, SingleVSize);
}
// TODO: this is used for non async copy desc. unify in the future
// (N, K) LDS descriptor for K in the synchronous-copy path: a 3d
// (K/KPack, N, KPack) layout whose K/KPack slices are padded by one
// KPack — presumably for LDS bank-conflict avoidance — then merged back
// into a 2d (N, K) view.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
}
// Store-side LDS descriptor for K in the async-copy path. Exposes an
// (issues, warps, lanes) view so each async issue writes one
// WarpSize*KVector chunk; IBuf selects which prefetch buffer to address
// via a constant offset.
template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{})
{
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK =
kKPerBlock / KVector; // how many lane (within a wave) to load K
constexpr index_t LaneGroups =
WarpSize /
LanesPerK; // how many groups (within a wave), they may load different N, but same K
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
// 5d view: (n0=issue, n1=lane-group, n2=warp, k0=lane, k1=vector), with
// kPad elements of padding after every warp chunk and an IBuf-sized
// buffer offset baked into the descriptor
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KVector>{},
number<1>{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return k_lds_block_desc_issues_warps_lanes;
}
// Load-side LDS descriptor for K: a flat (buffers*N, K) view spanning all
// NumKVLdsBuffers prefetch buffers, matching the padded layout written by
// MakeKLdsStoreBlockDescriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor()
{
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
// constexpr index_t SingleKSize = NumIssues * NumWarps * (WarpSize * KVector + kPad);
// constexpr index_t SingleVSize =
// MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
constexpr index_t BufferSize =
GetSingleSmemElementSpaceSize<Problem>(); // max(SingleKSize, SingleVSize);
// 6d view: (buffer, n0=issue, n2=warp, n1=lane-group, k0, k1=KPack),
// strided so consecutive buffers are BufferSize elements apart
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumKVLdsBuffers>{}, // num_buffers
number<NumIssues>{}, // n0
number<NumWarps>{}, // n2
number<LaneGroups>{}, // n1
number<kKPerBlock / KPack>{}, // k0
number<KPack>{}), // k1
make_tuple(number<BufferSize>{},
number<NumWarps*(WarpSize * KVector + kPad)>{},
number<WarpSize * KVector + kPad>{},
number<kKPerBlock>{},
number<KPack>{},
number<1>{}),
number<KPack>{},
number<1>{});
// merge (buffer, n0, n1, n2) -> rows and (k0, k1) -> columns
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<NumKVLdsBuffers>{},
number<NumIssues>{},
number<LaneGroups>{},
number<NumWarps>{})),
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0, 1, 3, 2>{}, sequence<4, 5>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
}
// 3d + padding
// LDS descriptor for V spanning all prefetch buffers: rows of
// PixelsPerRow pixels (sized from the LDS bank count) each padded by one
// kKPack, merged into a flat (buffers*N, K) view.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t Banks = get_n_lds_banks();
// one full sweep of the LDS banks, measured in VDataType pixels
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumKVLdsBuffers>{},
number<kKPerBlock / kKPack>{},
number<kNPerBlock / NPerRow>{},
number<NPerRow>{},
number<kKPack>{}),
make_tuple(number<GetSingleSmemElementSpaceSize<Problem>()>{},
number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
number<PixelsPerRow + kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(
number<NumKVLdsBuffers>{}, number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc;
}
// Total LDS bytes for the pipeline's K/V staging (plus any Q LDS the QX
// policy requests, which is zero for the QLoadOnce policy).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
{
    // TODO: assume Q is in register
    // TODO: assume K and V share smem buffers
    // One buffer must fit either a K tile or a V tile; size it in bytes of
    // the LDS K element type (packed K dtypes are staged as fp8).
    constexpr index_t bytes_per_buffer =
        GetSingleSmemElementSpaceSize<Problem>() * sizeof(SageAttnQKGemmKDataType<Problem>);
    return QXPolicy::template GetSmemSizeQ<Problem>() + NumKVLdsBuffers * bytes_per_buffer;
}
// Total LDS requirement of this policy (currently just the K/V space).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return GetSmemSizeKV<Problem>();
}
// Register distribution used to read K from global memory; the sync and
// async paths use different thread-to-element mappings.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
{
if constexpr(!AsyncCopy)
{
// sync path: K1-wide vectors per lane, warps tiled along N
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
constexpr index_t MaxVectorSize = GetLogicalVectorSize<KDataType>(16);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
constexpr index_t K1 = min(MaxVectorSize, ElemPerThread);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else
{
// async path: mapping mirrors MakeKLdsStoreBlockDescriptor
// (NumIssues x LaneGroups x NumWarps x LanesPerK x KVector)
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK0;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockSageAttnShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr index_t N0 = NumIssues;
constexpr index_t N1 = LaneGroups;
constexpr index_t N2 = NumWarps;
constexpr index_t K0 = LanesPerK;
constexpr index_t K1 = KVector;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
// Register distribution used to read V from global memory. Row-major V
// (seqlen * hdim) is shuffled before the LDS store (see
// MakeShuffledVRegBlockDescriptor); col-major V is consumed directly.
// Each layout has a fallback encoding for shapes where the primary
// mapping does not tile evenly.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
{
using VLayout = remove_cvref_t<typename Problem::BlockSageAttnShape::VLayout>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t K3 = total_pixels / N1;
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
{
// fallback: retile N into kNPack-sized groups
static_assert(kNPerBlock % 16 == 0);
constexpr index_t kNPack = kNPerBlock % 32 == 0 ? 32 : 16;
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t N2 = 2;
constexpr index_t N1_m = kNPack / N2;
constexpr index_t N0_m = kNPerBlock / kNPack;
constexpr index_t K1 = get_warp_size() / N1_m;
constexpr index_t K2_m = kKPerBlock / K1 / K0;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
tuple<sequence<0>, sequence<1, 1>>,
sequence<1, 2, 1>, // N0 K2 N2
sequence<0, 2, 2>>{});
}
else if constexpr(get_warp_size() % (K2 * N0) == 0)
{
// primary mapping: a wave covers (K2 x N0) with K1 waves along K
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
// (K2 x N0) wider than one wave: split K2 across waves
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
// col-major V: contiguous K1-vectors per lane, warps tiled along N
constexpr index_t K1 = GetAlignmentV<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = kNPerBlock / (N2 * N1);
static_assert(N0 != 0);
constexpr auto dstr = make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K0
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>, // N0 K1
sequence<0, 1>>{});
// only usable if the encoding exactly covers the N1 x K1 tile
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kNPerBlock * kKPerBlock)
{
return dstr;
}
else
{
// fallback: retile K into kKPerIter-sized groups
static_assert(kKPerBlock % 16 == 0);
constexpr index_t kKPerIter = kKPerBlock % 32 == 0 ? 32 : 16;
constexpr index_t K0_m = kKPerBlock / kKPerIter;
constexpr index_t K2 = 2;
constexpr index_t K1_m = kKPerIter / K2;
constexpr index_t N2_m = get_warp_size() / K1_m;
constexpr index_t N0_m = kNPerBlock / (N2_m * N1);
constexpr auto dstr_m = make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<N0_m, N1, N2_m>, sequence<K0_m, K1_m, K2>>,
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K1
tuple<sequence<1>, sequence<2, 1>>,
sequence<2, 1, 2>, // K0 N0 K2
sequence<0, 0, 2>>{});
static_assert(container_reduce(dstr_m.get_lengths(),
std::multiplies<index_t>{},
1) == kNPerBlock * kKPerBlock);
return dstr_m;
}
}
}
/// Builds the static tile distribution describing how each thread holds its
/// slice of a kN1 x kK1 V tile in registers for the shuffle path.
/// Only valid when V is stored row-major (seqlen * hdim) — enforced below.
/// Three encodings are produced depending on the divisibility of the derived
/// factors (see the if-constexpr chain).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegBlockDescriptor()
{
// This descriptor only used when V layout is seqlen * hdim
using VLayout = remove_cvref_t<typename Problem::BlockSageAttnShape::VLayout>;
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockSageAttnShape::kN1; // tile size along V head_dim
constexpr index_t kKPerBlock = Problem::BlockSageAttnShape::kK1; // tile size along kv-gemm unroll
constexpr index_t N1 = GetAlignmentV<Problem>(); // per-access alignment for V
constexpr index_t N0 = kNPerBlock / N1;
// Number of tile elements owned by each thread of the block.
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside single wave
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
{
// Fallback split: re-factor the N side by a 16/32-wide pack (kNPack) and
// derive the K-side factors from the warp/block geometry instead of kKPack.
static_assert(kNPerBlock % 16 == 0);
constexpr index_t kNPack = kNPerBlock % 32 == 0 ? 32 : 16;
constexpr index_t K0 = kBlockSize / get_warp_size(); // warps per block
constexpr index_t N2 = 2;
constexpr index_t N1_m = kNPack / N2;
constexpr index_t N0_m = kNPerBlock / kNPack;
constexpr index_t K1 = get_warp_size() / N1_m;
constexpr index_t K2_m = kKPerBlock / K1 / K0;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
tuple<sequence<0>, sequence<1, 1>>,
sequence<1, 1, 2>, // N0 K2 <-> N2
sequence<0, 2, 2>>{});
}
else if constexpr(get_warp_size() % (K2 * N0) == 0)
{
// A single warp covers (K2 * N0) evenly; the remaining lanes form K1.
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
// (K2 * N0) spans more than one warp: split K2 by a cross-warp factor K1
// and verify the K decomposition still covers kKPerBlock exactly.
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
}
/// Assembles the block-level GEMM object used for the second attention stage
/// (P * V): builds the BlockGemmProblem from the gemm1 tile shape, selects a
/// warp-level GEMM implementation, and wires both into the
/// BlockGemmARegBSmemCRegV2 (A in registers, B in shared memory, C in
/// registers) block gemm.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{
    using PType   = typename Problem::PDataType;
    using VType   = typename Problem::VDataType;
    using AccType = typename Problem::OaccDataType;
    using Shape   = typename Problem::BlockSageAttnShape;

    using GemmProblem =
        BlockGemmProblem<PType,
                         VType,
                         AccType,
                         Problem::kNumGemm1Warps * get_warp_size(),
                         TileGemmShape<sequence<Shape::kM0, Shape::kN1, Shape::kK1>,
                                       typename Shape::Gemm1BlockWarps,
                                       typename Shape::Gemm1WarpTile>>;

    // Pick the warp-level GEMM: a dedicated fp8 32x32x32 swizzled-B variant on
    // 64-lane warps, otherwise whatever the generic dispatcher selects for the
    // given data types and warp-tile extents.
    auto selected_warp_gemm = []() {
        if constexpr(get_warp_size() == 64 && std::is_same_v<PType, fp8_t> &&
                     std::is_same_v<VType, fp8_t> && std::is_same_v<AccType, float>)
        {
            // The specialized warp gemm only exists for a 32x32x32 warp tile.
            static_assert(Shape::Gemm1WarpTile::at(number<0>{}) == 32 &&
                          Shape::Gemm1WarpTile::at(number<1>{}) == 32 &&
                          Shape::Gemm1WarpTile::at(number<2>{}) == 32);
            return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<>{};
        }
        else
        {
            return WarpGemmDispatcher<PType,
                                      VType,
                                      AccType,
                                      Shape::Gemm1WarpTile::at(number<0>{}),
                                      Shape::Gemm1WarpTile::at(number<1>{}),
                                      Shape::Gemm1WarpTile::at(number<2>{}),
                                      true>{};
        }
    }();
    using WarpGemm = remove_cvref_t<decltype(selected_warp_gemm)>;

    using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<PType,
                                                                 VType,
                                                                 AccType,
                                                                 typename Shape::Gemm1BlockWarps,
                                                                 WarpGemm>;
    return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_custom_policy.hpp"
namespace ck_tile {
/// Default policy for the synchronous QRKSVS pipeline: Q is loaded once
/// up-front, copies are plain (non-async), and a single prefetch stage is
/// used for both K and V (see the inline parameter comments).
using BlockSageAttentionPipelineQRKSVSDefaultPolicy =
BlockSageAttnPipelineQRKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>;
} // namespace ck_tile

View File

@@ -0,0 +1,71 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/// Rounds a head dimension up to a tile length the kernels are generated for:
/// power-of-two values map to themselves, while the non-power-of-two head
/// dims 48, 80, 96, 160 and 192 map to fixed, hand-picked tile lengths.
/// Any other value is rejected at compile time via a dependent-false
/// static_assert.
///
/// @tparam Headdim head dimension to round up
/// @return the qualified tile length for Headdim
template <index_t Headdim>
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length()
{
    if constexpr(Headdim == 48)
        return 48;
    else if constexpr(Headdim == 80)
        return 96;
    else if constexpr(Headdim == 96)
        return 128;
    else if constexpr(Headdim == 160)
        return 256;
    else if constexpr(Headdim == 192)
        return 192;
    else if constexpr(is_power_of_two_integer(Headdim))
        return Headdim;
    else
        // Unsupported head dim: the message now lists every special case
        // handled above (80 was previously missing).
        static_assert(Headdim == 0,
                      "only Headdim of 48, 80, 96, 160, 192 and power-of-two is supported");
}
/// Compile-time tile-shape descriptor for the SageAttention forward pipeline.
/// BlockTile_ packs, in order: kM0, kN0, kK0, kN1, kK1, kQKHeaddim (see the
/// per-member comments below). The Gemm0/Gemm1 sequences give the warp layout
/// across the block and the per-warp tile of each of the two GEMM stages.
template <typename BlockTile_, // sequence<...
typename Gemm0BlockWarps_,
typename Gemm0WarpTile_,
typename Gemm1BlockWarps_,
typename Gemm1WarpTile_,
bool IsVLayoutRowMajor_>
struct TileSageAttnShape
{
using BlockTile = remove_cvref_t<BlockTile_>;
using Gemm0BlockWarps = remove_cvref_t<Gemm0BlockWarps_>;
using Gemm0WarpTile = remove_cvref_t<Gemm0WarpTile_>;
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
// Total warps per gemm stage = product of the per-dimension warp counts.
static constexpr index_t NumGemm0Warps =
reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
static constexpr index_t NumGemm1Warps =
reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{});
static_assert(NumGemm1Warps % NumGemm0Warps == 0);
// Block-wide warp count: the larger of the two stages.
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
static constexpr index_t kQKHeaddim =
BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
// once (or repeately load Q as a whole tile)
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
// Head dim rounded up to a tile length the kernels are generated for.
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length<kQKHeaddim>();
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
using VLayout = std::conditional_t<IsVLayoutRowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
};
} // namespace ck_tile

View File

@@ -0,0 +1,42 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/sageattention/block/block_sageattention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
namespace ck_tile {
/// Compile-time kernel traits for the SageAttention forward kernel: padding
/// switches for each tensor dimension, the quantization-scale granularity,
/// an optional occupancy override, and the derived per-scale token counts.
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* padding for hdim_q */,
bool kPadHeadDimV_ /* padding for hdim_v */,
BlockSageAttentionQuantScaleEnum QScaleEnum_,
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
struct TileSageAttnTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
// Granularity of the Q/K quantization scales (per-tensor/block/warp/thread).
static constexpr auto QScaleEnum = QScaleEnum_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
/// Tokens per Q/K descale along seqlen. Fine-to-coarse: PERTHREAD, PERWARP, then 128 for Q
/// (BLOCKSCALE / no_scale / pertensor). K: PERWARP 64, BLOCKSCALE 128, else 128.
static constexpr index_t kBlockScaleSizeQ =
QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERTHREAD ? 4
: QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERWARP ? 32
: 128;
static constexpr index_t kBlockScaleSizeK =
QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERTHREAD ? 16
: QScaleEnum_ == BlockSageAttentionQuantScaleEnum::PERWARP ? 64
: 128;
};
} // namespace ck_tile

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/sageattention/kernel/sageattn_fwd_kernel.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_enum.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_problem.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/sageattention/pipeline/block_sageattn_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/sageattention/pipeline/tile_sageattn_shape.hpp"
#include "ck_tile/ops/sageattention/pipeline/tile_sageattn_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"