Merge branch 'develop' into users/yiding12/fmha-bwd-workspace

This commit is contained in:
Yi DING
2026-04-27 15:07:41 +08:00
committed by GitHub
50 changed files with 5216 additions and 1120 deletions

View File

@@ -3,6 +3,7 @@
#pragma once
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
@@ -55,6 +56,7 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"

View File

@@ -0,0 +1,17 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// Selects how the batch_prefill / paged-attention pipelines address the KV cache when
// loading K/V tiles.
enum class BlockAttentionKVCacheLoadModeEnum
{
    // Default mode: SGPR-based SRD accessed through buffer_load_*.
    // Uses 32-bit byte addressing, so the KV pool must stay below 2GB.
    BUFFER_LOAD = 0,

    // Direct global_load_lds_* with 64-bit addressing; required once the KV cache
    // exceeds 2GB.
    GLOBAL_LOAD_LDS = 1,
};
} // namespace ck_tile

View File

@@ -32,6 +32,83 @@
namespace ck_tile {
namespace detail {
// Detection trait for N0-loop pipelines.
// is_n0loop_pipeline_v<T> is true only when T exposes a static member `kUseN0Loop` whose
// type is implicitly convertible to bool AND whose value is true. A missing member, or a
// member that evaluates to false, selects the primary template and yields false.
template <typename T, typename = void>
struct has_n0loop_flag : std::bool_constant<false>
{
};

// Chosen only when the enable_if condition is well-formed and true (SFINAE otherwise).
template <typename T>
struct has_n0loop_flag<
    T,
    typename std::enable_if<std::is_convertible<decltype(T::kUseN0Loop), bool>::value &&
                            T::kUseN0Loop>::type> : std::bool_constant<true>
{
};

template <typename T>
static inline constexpr bool is_n0loop_pipeline_v = has_n0loop_flag<T>::value;
// Detection trait for the `kIgnoreFastExp2` pipeline flag.
// A pipeline sets kIgnoreFastExp2 = true when it explicitly chooses not to use FAST_EXP2.
// The kernel layer queries ignore_fast_exp2_v<Pipeline> so that it does not hand such a
// pipeline an incorrect scale_s (i.e. one that was pre-multiplied for the fast-exp2 path).
// Yields false when the member is absent, not implicitly convertible to bool, or false.
template <typename T, typename = void>
struct has_ignore_fast_exp2_flag : std::bool_constant<false>
{
};

// Chosen only when the enable_if condition is well-formed and true (SFINAE otherwise).
template <typename T>
struct has_ignore_fast_exp2_flag<
    T,
    typename std::enable_if<std::is_convertible<decltype(T::kIgnoreFastExp2), bool>::value &&
                            T::kIgnoreFastExp2>::type> : std::bool_constant<true>
{
};

template <typename T>
static inline constexpr bool ignore_fast_exp2_v = has_ignore_fast_exp2_flag<T>::value;
// Detection trait for the `kIsNaiveHDimLoad` pipeline flag.
// "Naive hdim load" means tiles of hdim96/hdim160/hdim192 are loaded without padding the
// tensor_view/tile_window up to hdim128/hdim256. It is currently supported by the
// qr_ks_vs_whole_k_prefetch pipeline.
// Yields false when the member is absent, not implicitly convertible to bool, or false.
template <typename T, typename = void>
struct has_naive_hdim_load_flag : std::bool_constant<false>
{
};

// Chosen only when the enable_if condition is well-formed and true (SFINAE otherwise).
template <typename T>
struct has_naive_hdim_load_flag<
    T,
    typename std::enable_if<std::is_convertible<decltype(T::kIsNaiveHDimLoad), bool>::value &&
                            T::kIsNaiveHDimLoad>::type> : std::bool_constant<true>
{
};

template <typename T>
static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag<T>::value;
// Detection trait for the `kUseTrLoad` pipeline flag.
// is_using_trload_v<T> is true only when T exposes a static member `kUseTrLoad` that is
// implicitly convertible to bool and evaluates to true; otherwise it is false.
template <typename T, typename = void>
struct has_use_trload_flag : std::bool_constant<false>
{
};

// Chosen only when the enable_if condition is well-formed and true (SFINAE otherwise).
template <typename T>
struct has_use_trload_flag<
    T,
    typename std::enable_if<std::is_convertible<decltype(T::kUseTrLoad), bool>::value &&
                            T::kUseTrLoad>::type> : std::bool_constant<true>
{
};

template <typename T>
static inline constexpr bool is_using_trload_v = has_use_trload_flag<T>::value;
} // namespace detail
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdKernel
{
@@ -77,13 +154,14 @@ struct FmhaFwdKernel
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
static constexpr bool kUseTrLoad = detail::is_using_trload_v<FmhaPipeline>;
static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad;
#if defined(__gfx950__)
static constexpr bool kIsAvailable = true;
#else
static constexpr bool kIsAvailable = !kUseTrLoad;
#endif
static constexpr std::string_view kPipelineName = FmhaPipeline::name;
template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
@@ -444,7 +522,9 @@ struct FmhaFwdKernel
num_head_q,
nhead_ratio_qk,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
detail::ignore_fast_exp2_v<FmhaPipeline>
? scale_s
: static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
scale_s,
#endif
@@ -897,7 +977,9 @@ struct FmhaFwdKernel
num_head_q,
nhead_ratio_qk,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast<float>(scale_s * ck_tile::log2e_v<>),
detail::ignore_fast_exp2_v<FmhaPipeline>
? scale_s
: static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
scale_s,
#endif
@@ -1039,6 +1121,7 @@ struct FmhaFwdKernel
const void* seqlen_k_ptr,
const void* block_scale_seqstart_q_ptr,
const void* block_scale_seqstart_k_ptr,
const void* seqstart_v_scale_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -1097,6 +1180,7 @@ struct FmhaFwdKernel
seqlen_k_ptr,
block_scale_seqstart_q_ptr,
block_scale_seqstart_k_ptr,
seqstart_v_scale_ptr,
hdim_q,
hdim_v,
num_head_q,
@@ -1158,6 +1242,7 @@ struct FmhaFwdKernel
const void* seqlen_k_ptr,
const void* block_scale_seqstart_q_ptr,
const void* block_scale_seqstart_k_ptr,
const void* seqstart_v_scale_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -1216,6 +1301,7 @@ struct FmhaFwdKernel
seqlen_k_ptr,
block_scale_seqstart_q_ptr,
block_scale_seqstart_k_ptr,
seqstart_v_scale_ptr,
hdim_q,
hdim_v,
num_head_q,
@@ -1602,6 +1688,10 @@ struct FmhaFwdKernel
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
batch_offset_o;
constexpr index_t kQKHeaddimToUse = detail::is_naive_hdim_load_v<FmhaPipeline>
? FmhaPipeline::kQKHeaddim
: FmhaPipeline::kSubQKHeaddim;
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -1612,10 +1702,10 @@ struct FmhaFwdKernel
number<1>{});
if constexpr(FmhaPipeline::kQLoadOnce)
{
return pad_tensor_view(q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<kQKHeaddimToUse>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
else
{
@@ -1634,10 +1724,21 @@ struct FmhaFwdKernel
number<1>{});
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
{
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0Sub>{}, number<kQKHeaddimToUse>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
}();
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -1649,18 +1750,29 @@ struct FmhaFwdKernel
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_transposed = transform_tensor_view(
v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
if constexpr(!kUseTrLoad)
{
const auto v_dram_transposed = transform_tensor_view(
v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK_>{});
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK_>{});
}
else
{
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadHeadDimV>{});
};
}
else
{
@@ -1683,17 +1795,28 @@ struct FmhaFwdKernel
q_dram,
[&]() {
if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kSubQKHeaddim>{});
return make_tuple(number<FmhaPipeline::kM0>{}, number<kQKHeaddimToUse>{});
else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
}(),
{i_m0, 0});
auto k_dram_window = make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{0, 0});
auto k_dram_window = [&]() {
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
{
return make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kN0Sub>{}, number<kQKHeaddimToUse>{}),
{0, 0});
}
else
{
return make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{0, 0});
}
}();
auto v_dram_window = make_tile_window(
v_dram,
@@ -1843,7 +1966,10 @@ struct FmhaFwdKernel
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope *= ck_tile::log2e_v<>;
if constexpr(!detail::ignore_fast_exp2_v<FmhaPipeline>)
{
slope *= ck_tile::log2e_v<>;
}
#endif
if constexpr(kHasMask)
{
@@ -2826,7 +2952,10 @@ struct FmhaFwdKernel
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope *= ck_tile::log2e_v<>;
if constexpr(!detail::ignore_fast_exp2_v<FmhaPipeline>)
{
slope *= ck_tile::log2e_v<>;
}
#endif
if constexpr(kHasMask)
{

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
@@ -134,7 +135,8 @@ template <typename IndexArrayType,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
bool kIsKcache,
index_t kN0,
index_t kVectorSize>
index_t kVectorSize,
bool kUseGlobalLoad_ = false>
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages,
const index_t& stride_token,
const index_t& stride_page_block,
@@ -156,81 +158,65 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica
const index_t& thread_coord_start = coord_vec[kCoordAxis];
constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
if constexpr(kIsKcache)
{
// K cache: per-token lookup
// Each token may be on a different page, so we use physical_pages[k0] for each.
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
// Addressing strategy — four cases controlled by (kPageBlockSize vs kN0, kUseGlobalLoad_):
//
// Case 1: kPageBlockSize >= kN0
// SRD is rebased per-tile to the page base (rebase_{k,v}_window in caller).
// Page base is absorbed into the SRD's 48-bit base pointer (SGPR-resident).
// This function writes within-page offset only.
//
// Case 2: kPageBlockSize < kN0 && kUseGlobalLoad_
// SRD cannot be rebased (multi-page wave). Loads use global_load_lds_*; the full
// 64-bit address is computed by tile_scatter_gather::load() in
// include/ck_tile/core/tensor/tile_scatter_gather.hpp from physical_pages_ +
// page_stride_elements_. This function writes within-page offset only.
//
// Case 3: kPageBlockSize < kN0 && !kUseGlobalLoad_ (kNeedFullOffset == true)
// SRD base is the entire KV buffer; the only place to encode page identity
// is the voffset itself. This function writes the FULL offset:
// page * stride_page_block + within_page
// Limited to <2GB total KV bytes by 32-bit voffset hardware width.
//
// Case 4: kPageBlockSize >= kN0 && kUseGlobalLoad_
// Not emitted by codegen. Backstop static_assert in
// BlockFmhaBatchPrefillPipelineQRKSVSAsync.
constexpr bool kNeedFullOffset = (kPageBlockSize < kN0) && !kUseGlobalLoad_;
if constexpr(kPageBlockSize >= kN0)
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
// Within-page offset (layout-dependent for V cache with VECTORIZED_LAYOUT)
const index_t within_page = [&]() {
if constexpr(!kIsKcache && kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
// SRD rebasing mode: within-page offset only.
// The full page base is handled by rebasing the SRD pointer.
kv_offset_vec[k0] = token_idx_in_page * stride_token;
return (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
(token_idx_in_page % kVectorSize);
}
else
{
// Full global offset (original code path for ps1, ps16, etc.)
const index_t physical_page = physical_pages[k0];
kv_offset_vec[k0] =
physical_page * stride_page_block + token_idx_in_page * stride_token;
return token_idx_in_page * stride_token;
}
});
}
else // V cache
{
// V cache: use physical_pages[k0] for each token
// physical_pages was already populated correctly by load_physical_pages(), handling:
// - page_size=1: page_idx maps token_idx -> physical_page directly
// - V tile crosses pages: per-token page lookup
// - V tile in single page: lane0 lookup with broadcast to all lanes
static_for<0, kLoopCount, 1>{}([&](auto k0) {
const index_t global_token_idx =
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
}();
if constexpr(kPageBlockSize >= kN0)
{
// SRD rebasing mode: within-page offset only.
// The full page base is handled by rebasing the SRD pointer.
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
const index_t token_offset =
(token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
(token_idx_in_page % kVectorSize);
kv_offset_vec[k0] = token_offset;
}
else
{
kv_offset_vec[k0] = token_idx_in_page * stride_token;
}
}
else
{
// Full global offset (original code path for ps1, ps16, etc.)
const index_t physical_page = physical_pages[k0];
const long_index_t page_base_offset =
static_cast<long_index_t>(physical_page) * stride_page_block;
if constexpr(kKVMemoryLayout ==
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
{
const index_t token_offset =
(token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) +
(token_idx_in_page % kVectorSize);
kv_offset_vec[k0] = page_base_offset + token_offset;
}
else
{
kv_offset_vec[k0] = page_base_offset + token_idx_in_page * stride_token;
}
}
});
}
// SRD + page_size < kN0: add page base to form complete voffset for buffer_load.
//
// 32-bit by hardware: SRD buffer_load voffset is fundamentally 32-bit (CDNA3 MUBUF
// microcode format), so this branch is only reachable when total KV bytes fit in
// INT32_MAX. The kUseGlobalLoad_ template path handles the >2GB case via 64-bit
// global_load_lds_*; widening kv_offset_vec here would not lift the 2GB ceiling
// because the hardware truncates voffset regardless.
if constexpr(kNeedFullOffset)
{
kv_offset_vec[k0] = physical_pages[k0] * stride_page_block + within_page;
}
else
{
kv_offset_vec[k0] = within_page;
}
});
}
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
@@ -270,10 +256,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr index_t kPageBlockSize = Problem::kPageBlockSize;
static constexpr index_t kVectorSize = Problem::kVectorSize;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
// Single load-mode selector for the whole pipeline. GLOBAL_LOAD_LDS routes K/V
// tiles through global_load_lds_* (handles >2GB KV cache); BUFFER_LOAD uses SRD
// buffer_load_*. The enum is named at the trait/Problem level; internally we
// derive a bool helper to keep `if constexpr` sites narrow. Codegen only emits
// GLOBAL_LOAD_LDS arms when page_size < kN0; the static_assert is a backstop.
static constexpr auto kKVLoadMode = Problem::kKVLoadMode;
static constexpr bool kUseGlobalLoad =
(kKVLoadMode == BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS);
static_assert(!kUseGlobalLoad || (kPageBlockSize < kN0),
"GLOBAL_LOAD_LDS load mode is only valid when kPageBlockSize < kN0; "
"codegen should not emit this instantiation otherwise.");
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
@@ -626,19 +623,26 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
true,
kN0,
kVectorSize>(
kVectorSize,
kUseGlobalLoad>(
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
k_dist,
k_offsets); // K DRAM tile window for
k_offsets,
bool_constant<kUseGlobalLoad>{},
page_stride_k);
if constexpr(kUseGlobalLoad)
{
k_dram_window.update_physical_pages(k_physical_pages);
}
k_dram_window.init_raw();
// SRD rebasing: move the buffer descriptor base pointer to each page's start address
// using 48-bit pointer arithmetic, so voffset only needs the small within-page offset.
// Only applies when kPageBlockSize >= kN0 (all threads in a wave access the same page).
// SRD rebasing for K: only for page_size >= kN0 (all threads on same page).
// For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle
// addressing.
auto rebase_k_window = [&](auto& window, index_t physical_page) {
if constexpr(kPageBlockSize >= kN0)
{
@@ -649,24 +653,36 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const auto* page_ptr =
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_k;
window.set_bottom_tensor_view_data_ptr(page_ptr);
// Limit SRD num_records to one page worth of elements.
// Without this, the SRD claims validity for [page_ptr, page_ptr +
// full_buffer_size), which extends far beyond the allocated buffer when rebased to
// high pages. On gfx950, the hardware may validate the full SRD range against page
// table permissions, causing faults on freed/protected memory beyond the buffer.
window.set_bottom_tensor_view_buffer_size(page_stride_k);
window.init_raw();
}
};
// SRD rebasing for V: only for page_size >= kN0 (all threads on same page).
// For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle
// addressing.
//
// Parameters:
//   window        - tile window whose bottom tensor view data pointer and buffer size are
//                   overwritten, followed by a call to its init_raw().
//   physical_page - page index; multiplied by page_stride_v to locate the page start.
// The whole body sits under `if constexpr(kPageBlockSize >= kN0)`, so the lambda does
// nothing when kPageBlockSize < kN0.
auto rebase_v_window = [&](auto& window, index_t physical_page) {
if constexpr(kPageBlockSize >= kN0)
{
// readfirstlane: make physical_page provably wave-uniform so the
// resulting SRD lands in SGPRs (required by buffer load instructions).
physical_page = __builtin_amdgcn_readfirstlane(physical_page);
// The base is always read from the original v_dram_block_window_tmp view, not from
// `window`, so the offset below is relative to that view's data pointer.
const auto* base_ptr =
v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_;
// The page index is cast to long_index_t before the multiply (presumably so the
// element offset is not computed in 32-bit index_t -- confirm long_index_t width).
const auto* page_ptr =
base_ptr + static_cast<long_index_t>(physical_page) * page_stride_v;
window.set_bottom_tensor_view_data_ptr(page_ptr);
// Buffer size is set to page_stride_v, i.e. one page stride, mirroring the K-window
// rebase.
window.set_bottom_tensor_view_buffer_size(page_stride_v);
// NOTE(review): init_raw() presumably refreshes the window's raw buffer descriptor after
// the pointer/size change -- confirm against the tile window implementation.
window.init_raw();
}
};
// Initial K SRD rebase
// Initial K SRD rebase (no-op for page_size < kN0, uses flat loads instead)
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
constexpr auto k_oob_ck = bool_constant<true>{};
@@ -874,12 +890,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
false,
kN0,
kVectorSize>(v_physical_pages_k2,
stride_v,
page_stride_v,
v_coord,
v_offsets_k2,
current_seq_k);
kVectorSize,
kUseGlobalLoad>(v_physical_pages_k2,
stride_v,
page_stride_v,
v_coord,
v_offsets_k2,
current_seq_k);
static_for<0, V_KIterInner, 1>{}([&](auto k1) {
constexpr auto idx = number<k1.value + k2.value * V_KIterInner>{};
@@ -899,9 +916,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
false,
kN0,
kVectorSize>(
kVectorSize,
kUseGlobalLoad>(
v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
}
// v_offsets semantics — see the four-case addressing-strategy block above
// kNeedFullOffset in kv_offset_array_transform. Three cases reach this lambda:
// Case 1 (kPageBlockSize >= kN0): within-page offset; page base in SRD.
// Case 2 (page_size < kN0, kUseGlobalLoad): within-page offset; page base computed
// by tile_scatter_gather::load() from
// physical_pages_.
// Case 3 (page_size < kN0, !kUseGlobalLoad, == kNeedFullOffset):
// FULL offset (page * stride + within),
// carried in the 32-bit voffset (<2GB cap).
};
// Prefetch V physical pages early to hide buffer load latency
@@ -915,11 +943,32 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
v_offsets,
number<1>{}, // HsGatherDim
number<1>{}, // NumCoord
VPageIndexYDims);
VPageIndexYDims,
bool_constant<kUseGlobalLoad>{},
page_stride_v);
if constexpr(kUseGlobalLoad)
{
v_dram_window.update_physical_pages(v_physical_pages);
}
// Initial V SRD rebase
// Initial V SRD rebase. Single source of truth: rebase_v_window's own
// `if constexpr(kPageBlockSize >= kN0)` makes this a no-op for case 2/3.
// Do not re-add an outer guard here — it would duplicate the inner check
// and drift if the lambda's gating condition ever changes.
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
// Two-step helper that must run in this order:
//   1. (kUseGlobalLoad only) copy the *current* tile's V physical pages into
//      v_dram_window, because the window reads physical_pages_ in that mode;
//   2. call prefetch_v_physical_pages, which overwrites the v_physical_pages buffer
//      with the *next* tile's pages.
// When kUseGlobalLoad is false, step 1 is compiled out and this is just the prefetch.
// Keeping both steps in one lambda keeps the save-before-overwrite ordering in a single
// place for every prefetch site.
auto save_and_prefetch_v_pages = [&](auto k_loop_start) {
    if constexpr(kUseGlobalLoad)
    {
        v_dram_window.update_physical_pages(v_physical_pages);
    }

    prefetch_v_physical_pages(k_loop_start);
};
// prefetch K tile
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
@@ -972,7 +1021,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
}
// Prefetch V physical pages early - overlaps with GEMM0 computation
prefetch_v_physical_pages(number<kK1>{});
save_and_prefetch_v_pages(number<kK1>{});
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
@@ -1166,7 +1215,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// Prefetch V physical pages early - overlaps with softmax computation
if constexpr(k1_loops > 1)
{
prefetch_v_physical_pages(number<2 * kK1>{});
save_and_prefetch_v_pages(number<2 * kK1>{});
}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
@@ -1220,8 +1269,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
v_dram_window,
{0,
kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
update_v_offsets(number<2 * kK1>{});
v_dram_window.update_page_idx(v_offsets);
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
@@ -1390,8 +1438,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
// Update V offsets using previously prefetched physical pages
update_v_offsets(number<(2 + i_k1.value) * kK1>{});
v_dram_window.update_page_idx(v_offsets);
@@ -1401,7 +1448,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
// Prefetch V physical pages for NEXT iteration - overlaps with GEMM1
if constexpr(i_k1 + 1 < k1_loops - 1)
{
prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{});
save_and_prefetch_v_pages(number<(2 + i_k1.value + 1) * kK1>{});
}
block_sync_lds();
@@ -1481,9 +1528,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
kKVMemoryLayout,
true,
kN0,
kVectorSize>(
kVectorSize,
kUseGlobalLoad>(
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
k_dram_window.update_page_idx(k_offsets);
if constexpr(kUseGlobalLoad)
k_dram_window.update_physical_pages(k_physical_pages);
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
// After sink→window transition (i_total_loops == num_sink_loop), V window

View File

@@ -9,6 +9,52 @@
namespace ck_tile {
namespace detail {
// Picks the widest per-thread vector length (in elements) that evenly divides
// ElemPerThread for the given element type:
//   - half_t / bf16_t : 8, 4, 2, otherwise 1
//   - float           : 4, 2, otherwise 1
//   - any other type  : always 1
// ToDo: need support in ck_tile for using buffer_load_dwordx3; once available, lengths
// of 6 (half_t/bf16_t) and 3 (float) could be considered as well.
template <typename DataType, index_t ElemPerThread>
CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize()
{
    constexpr bool kIsHalfLike =
        std::is_same_v<DataType, half_t> || std::is_same_v<DataType, bf16_t>;
    constexpr bool kIsFloat      = std::is_same_v<DataType, float>;
    constexpr bool kIsVectorized = kIsHalfLike || kIsFloat;

    // A length of 8 is only offered to the half-like types; float tops out at 4.
    if constexpr(kIsHalfLike && ElemPerThread % 8 == 0)
        return 8;
    else if constexpr(kIsVectorized && ElemPerThread % 4 == 0)
        return 4;
    else if constexpr(kIsVectorized && ElemPerThread % 2 == 0)
        return 2;
    else
        return 1;
}
// Maximum vector length for a DRAM tile access: the kHigherDimSize x kLowerDimSize tile is
// divided (integer division) across kThreadBlockSize threads, and the resulting per-thread
// element count is passed to GetMaxVectorSize for the given DataType.
template <typename DataType,
          index_t kThreadBlockSize,
          index_t kHigherDimSize,
          index_t kLowerDimSize>
CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize()
{
    constexpr index_t kTileElems      = kHigherDimSize * kLowerDimSize;
    constexpr index_t kElemsPerThread = kTileElems / kThreadBlockSize;

    return GetMaxVectorSize<DataType, kElemsPerThread>();
}
} // namespace detail
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
@@ -117,6 +163,12 @@ struct BlockFmhaBatchPrefillPipelineProblem
static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
"kPageBlockSize must be power of two");
// KV cache load addressing mode. GLOBAL_LOAD_LDS handles >2GB pools via
// 64-bit addressing; BUFFER_LOAD (default) uses SRD buffer_load for the
// <2GB fast path. The 2GB bound = INT32_MAX byte offset, matching CK's
// existing TwoGB convention.
static constexpr auto kKVLoadMode = Traits_::kKVLoadMode;
static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4
static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout;
static constexpr auto kKVLookupTable = Traits_::kKVLookupTable;

View File

@@ -0,0 +1,861 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using CompDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true;
static_assert(kQLoadOnce == Policy::QLoadOnce);
static_assert(sizeof(KDataType) == sizeof(VDataType) &&
alignof(KDataType) == alignof(VDataType),
"K and V share the same LDS region; their element types must have identical "
"size and alignment.");
static constexpr bool kUseN0Loop = true;
static constexpr bool kIgnoreFastExp2 = true;
static constexpr bool kIsNaiveHDimLoad = true;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kN0Sub =
BlockFmhaShape::kK0; // subdivision of kN0 used in N0-loop, same value as kK0
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
static_assert(Problem::kUseTrLoad == true, "Check failed!");
static constexpr bool kUseTrLoad = true;
// since this pipeline is only used by the inference path of xformers, the Dropout function is
// not well tested with the pipeline, so here we have Dropout disabled
static_assert(kHasDropout == false, "Dropout is not supported by this pipeline at present!");
// Last-dimension vector length used, together with the tensor distribution, to create the
// tensor view (and to decide the buffer_load vector length). The tensor distribution should
// be able to override this.
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
// kBlockPerCu: an explicit Problem::kBlockPerCu (anything other than -1) wins; otherwise the
// value is picked from kQKHeaddim:
//   32 / 64   -> 2
//   96 / 128  -> 1 with ELEMENTWISE_BIAS, else 2
//   otherwise -> 1
static constexpr index_t kBlockPerCu = []() {
    if constexpr(Problem::kBlockPerCu != -1)
    {
        return Problem::kBlockPerCu;
    }
    else if constexpr(kQKHeaddim == 32 || kQKHeaddim == 64)
    {
        return 2;
    }
    else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128)
    {
        return BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ? 1 : 2;
    }
    else
    {
        return 1;
    }
}();
// Identifier string of this pipeline.
static constexpr const char* name = "qr_async_whole_k_prefetch_trload";
// Dropout implementation selected at compile time from kHasDropout.
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
// Returns the shared-memory size reported by the policy for this Problem.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    const ck_tile::index_t smem_size = Policy::template GetSmemSize<Problem>();
    return smem_size;
}
// Forward FMHA for one Q tile: S = Q@K (Gemm0), online softmax with running row-max `m` and
// row-sum `l`, then O = P@V (Gemm1), iterating over the K/V sequence in steps of kN0.
//
// Parameters (as used below):
//   q/k/v/bias_dram_block_window_tmp : DRAM windows; only their bottom tensor views and (for Q and
//                                      bias) window origins are used to build the local windows.
//   *_element_func                   : element-wise functors applied to Q, K, V, bias, S-acc,
//                                      P-compute, LSE and O-acc respectively.
//   randval_dram_block_window_tmp    : ignored by this pipeline.
//   lse_dram_window_tmp              : written only when kStoreLSE is true.
//   mask                             : supplies the [start, end) K range of this Q tile and the
//                                      per-element out-of-bound test.
//   position_encoding                : used only for BiasEnum == ALIBI.
//   scale_s                          : multiplies S before bias / mask / softmax.
//   smem_ptr                         : base address for both the K and the V LDS views.
//   dropout                          : only touched under `if constexpr(kHasDropout)`.
// Returns the O accumulator tile after normalisation by `l` and o_acc_element_func.
template <typename QDramBlockWindowTmp,
          typename KDramBlockWindowTmp,
          typename VDramBlockWindowTmp,
          typename BiasDramBlockWindowTmp,
          typename RandValDramBlockWindowTmp,
          typename LSEDramBlockWindowTmp,
          typename QElementFunction,
          typename KElementFunction,
          typename VElementFunction,
          typename BiasElementFunction,
          typename LSEElementFunction,
          typename SAccElementFunction,
          typename PComputeElementFunction,
          typename OAccElementFunction,
          typename PositionEncoding,
          typename AttentionVariantParams,
          typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
           const QElementFunction& q_element_func,
           const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
           const KElementFunction& k_element_func,
           const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
           const VElementFunction& v_element_func,
           const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
           const BiasElementFunction& bias_element_func,
           RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
           LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
           const LSEElementFunction& lse_element_func,
           const SAccElementFunction& s_acc_element_func,
           const PComputeElementFunction& p_compute_element_func,
           const OAccElementFunction& o_acc_element_func,
           FmhaMask mask,
           PositionEncoding position_encoding,
           float scale_s,
           const AttentionVariant& /* unused */,
           const AttentionVariantParams& /* unused */,
           const BlockIndices& /* unused */,
           void* smem_ptr,
           DropoutType& dropout) const
{
    // xformers path does not require the pipeline to output random values for host
    // verification, since a separate kernel is used to generate random values
    ignore = randval_dram_block_window_tmp;
    static_assert(
        std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
            std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
            std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
        "wrong!");
    static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kN0Sub == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                      kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                      kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                  "wrong!");
    constexpr auto I0 = number<0>{};
    constexpr auto I1 = number<1>{};
    // number of kN0Sub-wide Gemm0 steps and kK1-deep Gemm1 steps per kN0 iteration
    constexpr index_t n0_loops = kN0 / kN0Sub;
    constexpr index_t k1_loops = kN0 / kK1;
    // usually kN0 is 128, kN0Sub/kK1 is 32/16
    static_assert(n0_loops >= 2, "n0_loops >= 2 required to use this pipeline");
    static_assert(k1_loops >= 2, "k1_loops >= 2 required to use this pipeline");
    constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
    constexpr index_t NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();
    static_assert(n0_loops >= NumPrefetchV, "Check failed!");
    static_assert(k1_loops >= NumPrefetchV, "Check failed!");
    constexpr bool kPreloadWholeNextIterationK =
        Policy::template IsPreloadWholeNextIterationK<Problem>();
    // This path prefetches two k_tiles for next iteration, so it has the opportunity to
    // prefetch two v_tiles during Gemm0
    if constexpr(!kPreloadWholeNextIterationK)
    {
        static_assert(NumPrefetchV >= 2);
    };
    // Block GEMM
    constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
    constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
    // SaccBlockTile size is [kM0, kN0Sub]
    // PcompBlockTile size is [kM0, kN0]
    using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
    using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0>());
    using PcompBlockTileType = decltype(cast_tile<CompDataType>(CombineSaccBlockTileType{}));
    SaccBlockTileType sacc_tile;
    PcompBlockTileType pcomp_tile;
    // reduction function for softmax
    const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
    const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
    using MLBlockTileType = decltype(block_tile_reduce<CompDataType>(
        PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0}));
    // running row-max (m) and running row-sum (l) of the online softmax
    auto m = MLBlockTileType{};
    auto l = MLBlockTileType{};
    using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
    OaccBlockTileType o_acc;
    auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                                          make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                                          q_dram_block_window_tmp.get_window_origin(),
                                          Policy::template MakeQRegTileDistribution<Problem>());
    const auto q_origin = q_dram_window.get_window_origin();
    const auto [seqlen_k_start, seqlen_k_end] =
        mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
    if(seqlen_k_end <= seqlen_k_start)
    {
        // Nothing to attend to for this Q tile. The LSE rows must still be written, otherwise
        // they stay uninitialised; -inf is what the regular path produces for a fully masked
        // row (m == -inf, l == 0 => m + log(l) == -inf).
        if constexpr(kStoreLSE)
        {
            auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
            set_tile(lse, -numeric<CompDataType>::infinity());
            store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
        }
        clear_tile(o_acc);
        o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
        return o_acc;
    };
    auto k_dram_window =
        make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
                         make_tuple(number<kN0Sub>{}, number<kQKHeaddim>{}),
                         {seqlen_k_start, 0},
                         Policy::template MakeKDramTileDistribution<Problem>());
    auto q_tile = load_tile(q_dram_window);
    using k_tile_type = decltype(load_tile(k_dram_window));
    // register staging for K: a full kN0 worth of sub-tiles, or just two of them
    auto k_tiles = [&]() {
        if constexpr(kPreloadWholeNextIterationK)
            return statically_indexed_array<k_tile_type, n0_loops>{};
        else
            return statically_indexed_array<k_tile_type, 2>{};
    }();
    k_tiles[I0] = load_tile(k_dram_window);
    move_tile_window(k_dram_window, {kN0Sub, 0});
    if constexpr(!kPreloadWholeNextIterationK)
    {
        k_tiles[I1] = load_tile(k_dram_window);
        move_tile_window(k_dram_window, {kN0Sub, 0});
    };
    __builtin_amdgcn_sched_barrier(0x00000001);
    // provide partition_index for the LDS tile windows so that warp_id is kept in a vgpr
    array<index_t, 2> partition_index{get_warp_id<false>(), get_lane_id()};
    // K tile in LDS
    KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
    auto k_lds = make_tensor_view<address_space_enum::lds>(
        k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
    auto k_lds_window = make_tile_window(
        k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
    using k_lds_window_type = decltype(get_slice_tile(
        k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
    // buffer i covers rows [i * kN0Sub, (i + 1) * kN0Sub) of the K LDS window
    statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
    static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
        k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
                                              sequence<i_buf * kN0Sub, 0>{},
                                              sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
    });
    // V tile in LDS (built on the same smem_ptr base as the K view above)
    auto v_lds = make_tensor_view<address_space_enum::lds>(
        reinterpret_cast<VDataType*>(smem_ptr),
        Policy::template MakeVLdsBlockDescriptor<Problem>());
    auto v_lds_window = make_tile_window(
        v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
    using v_lds_window_type =
        decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kK1, kN1>{}));
    // buffer i covers rows [i * kK1, (i + 1) * kK1) of the V LDS window
    statically_indexed_array<v_lds_window_type, NumKVLdsBuffers> v_lds_windows;
    static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
        v_lds_windows[i_buf] = get_slice_tile(
            v_lds_window, sequence<i_buf * kK1, 0>{}, sequence<(i_buf + 1) * kK1, kN1>{});
    });
    auto v_dram_window =
        make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
                         make_tuple(number<kK1>{}, number<kN1>{}),
                         {seqlen_k_start, 0},
                         Policy::template MakeVDramTileDistribution<Problem>());
    const auto f_exp = [&](CompDataType x) {
        if constexpr(std::is_same_v<CompDataType, float>)
        {
            return __expf(x);
        }
        else
        {
            return exp(x);
        }
    };
    const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
    auto bias_dram_window =
        make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
                         make_tuple(number<kM0>{}, number<kN0>{}),
                         {bias_origin.at(number<0>{}), seqlen_k_start},
                         Policy::template MakeBiasDramTileDistribution<Problem>());
    // assuming no random values need be saved, this is true when the pipeline is called from
    // xformers, since we have a separate kernel to generate random values
    auto null_randval_window = [&]() {
        if constexpr(kHasDropout)
        {
            // need to pass a null_randval_dram and tile window to the BlockDropout operator to
            // make it work
            const auto null_randval_dram = [&]() {
                const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    static_cast<uint8_t*>(nullptr),
                    make_tuple(1, 1),
                    make_tuple(1, 1),
                    number<1>{},
                    number<1>{});
                return pad_tensor_view(null_dram_naive,
                                       make_tuple(number<1>{}, number<1>{}),
                                       sequence<true, true>{});
            }();
            return make_tile_window(
                null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
        }
        else
            return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
    }();
    clear_tile(o_acc);
    set_tile(m, -numeric<CompDataType>::infinity());
    clear_tile(l);
    q_tile = tile_elementwise_in(q_element_func, q_tile);
    auto seqlen_k_curr = seqlen_k_start;
    using v_tile_type = decltype(load_tile(v_dram_window));
    statically_indexed_array<v_tile_type, k1_loops> v_tiles;
    do
    {
        // STAGE 1, Gemm_0 ( S = Q@K )
        // Every variant below follows the same per-sub-tile pattern: store the staged K sub-tile
        // to its LDS buffer, issue further DRAM loads, sync LDS, run gemm_0 into sacc_tile and
        // copy the result into columns [i_n0 * kN0Sub, (i_n0 + 1) * kN0Sub) of pcomp_tile.
        // The variants differ only in which K/V DRAM loads are issued in between.
        if constexpr(kPreloadWholeNextIterationK) // used when kM0 = 64
        {
            if(seqlen_k_curr == seqlen_k_start) // at first iteration
            {
                if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration
                {
                    static_for<0, n0_loops, 1>{}([&](auto i_n0) {
                        store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
                                   tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
                                   partition_index);
                        if constexpr(i_n0 < n0_loops - 1)
                        {
                            k_tiles[number<i_n0 + 1>{}] = load_tile(k_dram_window);
                            move_tile_window(k_dram_window, {kN0Sub, 0});
                        };
                        if constexpr(i_n0 == n0_loops - 1)
                        {
                            v_tiles[I0] = load_tile(v_dram_window);
                            move_tile_window(v_dram_window, {kK1, 0});
                            // prefetch all k_tiles for next iteration
                            static_for<0, n0_loops, 1>{}([&](auto ii_n0) {
                                k_tiles[number<ii_n0>{}] = load_tile(k_dram_window);
                                move_tile_window(k_dram_window, {kN0Sub, 0});
                            });
                        };
                        block_sync_lds();
                        gemm_0(
                            sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
                        sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
                        auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
                        set_slice_tile(pcomp_tile,
                                       tmp_tile,
                                       sequence<0, i_n0 * kN0Sub>{},
                                       sequence<kM0, (i_n0 + 1) * kN0Sub>{});
                    });
                }
                else // the iteration is also the last iteration
                {
                    static_for<0, n0_loops, 1>{}([&](auto i_n0) {
                        store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
                                   tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
                                   partition_index);
                        if constexpr(i_n0 < n0_loops - 1)
                        {
                            k_tiles[number<i_n0 + 1>{}] = load_tile(k_dram_window);
                            move_tile_window(k_dram_window, {kN0Sub, 0});
                        };
                        if constexpr(i_n0 == n0_loops - 1)
                        {
                            v_tiles[I0] = load_tile(v_dram_window);
                            move_tile_window(v_dram_window, {kK1, 0});
                        };
                        block_sync_lds();
                        gemm_0(
                            sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
                        sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
                        auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
                        set_slice_tile(pcomp_tile,
                                       tmp_tile,
                                       sequence<0, i_n0 * kN0Sub>{},
                                       sequence<kM0, (i_n0 + 1) * kN0Sub>{});
                    });
                };
            }
            else // at intermediate and last iteration
            {
                if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration
                {
                    static_for<0, n0_loops, 1>{}([&](auto i_n0) {
                        store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
                                   tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
                                   partition_index);
                        if constexpr(i_n0 == 0)
                        {
                            v_tiles[I0] = load_tile(v_dram_window);
                            move_tile_window(v_dram_window, {kK1, 0});
                        };
                        // prefetch k_tile for next iteration
                        k_tiles[i_n0] = load_tile(k_dram_window);
                        move_tile_window(k_dram_window, {kN0Sub, 0});
                        block_sync_lds();
                        gemm_0(
                            sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
                        sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
                        auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
                        set_slice_tile(pcomp_tile,
                                       tmp_tile,
                                       sequence<0, i_n0 * kN0Sub>{},
                                       sequence<kM0, (i_n0 + 1) * kN0Sub>{});
                    });
                }
                else // last iteration
                {
                    static_for<0, n0_loops, 1>{}([&](auto i_n0) {
                        store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
                                   tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
                                   partition_index);
                        if constexpr(i_n0 == 0)
                        {
                            v_tiles[I0] = load_tile(v_dram_window);
                            move_tile_window(v_dram_window, {kK1, 0});
                        };
                        block_sync_lds();
                        gemm_0(
                            sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
                        sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
                        auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
                        set_slice_tile(pcomp_tile,
                                       tmp_tile,
                                       sequence<0, i_n0 * kN0Sub>{},
                                       sequence<kM0, (i_n0 + 1) * kN0Sub>{});
                    });
                };
            }
        }
        else // only preload one unroll of K for next iteration, used when kM0=128
        {
            static_for<0, n0_loops, 1>{}([&](auto i_n0) {
                store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
                           tile_elementwise_in(k_element_func, k_tiles[number<i_n0 % 2>{}]),
                           partition_index);
                __builtin_amdgcn_sched_barrier(0x00000001);
                // refill the register slot that was just stored with the K sub-tile two steps ahead
                if constexpr(i_n0 < n0_loops - 2)
                {
                    k_tiles[number<i_n0 % 2>{}] = load_tile(k_dram_window);
                    move_tile_window(k_dram_window, {kN0Sub, 0});
                };
                // during the last two Gemm0 steps, prefetch the first two V tiles instead
                if constexpr(i_n0 >= n0_loops - 2)
                {
                    v_tiles[number<i_n0 - (n0_loops - 2)>{}] = load_tile(v_dram_window);
                    move_tile_window(v_dram_window, {kK1, 0});
                };
                __builtin_amdgcn_sched_barrier(0x00000001);
                block_sync_lds();
                gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
                sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
                auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
                set_slice_tile(pcomp_tile,
                               tmp_tile,
                               sequence<0, i_n0 * kN0Sub>{},
                               sequence<kM0, (i_n0 + 1) * kN0Sub>{});
            });
        }
        __builtin_amdgcn_sched_barrier(0x000000001);
        const auto bias_tile = load_tile(bias_dram_window); // load bias tile
        // STAGE 2, scale_s, add bias, mask, softmax
        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
            tile_elementwise_inout(
                [&](auto& x, const auto y) {
                    x += type_convert<CompDataType>(bias_element_func(y));
                },
                pcomp_tile,
                bias_tile);
        }
        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            constexpr auto pcomp_spans = decltype(pcomp_tile)::get_distributed_spans();
            sweep_tile_span(pcomp_spans[number<0>{}], [&](auto idx0) {
                sweep_tile_span(pcomp_spans[number<1>{}], [&](auto idx1) {
                    const auto tile_idx = get_x_indices_from_distributed_indices(
                        pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
                    const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                    const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    pcomp_tile(i_j_idx) *= scale_s;
                    position_encoding.update(pcomp_tile(i_j_idx), row, col);
                });
            });
        }
        else
        {
            tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
        }
        move_tile_window(bias_dram_window, {0, kN0});
        if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
        {
            bool need_perpixel_check = mask.IsEdgeTile(
                q_origin.at(number<0>{}), seqlen_k_curr, number<kM0>{}, number<kN0>{});
            if(need_perpixel_check)
            {
                set_tile_if(pcomp_tile, -numeric<CompDataType>::infinity(), [&](auto tile_idx) {
                    const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                    const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
                    return mask.IsOutOfBound(row, col);
                });
            }
        }
        __builtin_amdgcn_sched_barrier(0x00000001);
        // m = max(m_old, row-max of the current S tile)
        auto m_local = block_tile_reduce<CompDataType>(
            pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
        block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
        const auto m_old = m;
        tile_elementwise_inout(
            [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
        __builtin_amdgcn_sched_barrier(0);
        // check whether the first V LDS buffer (index 2 % NumKVLdsBuffers) is the same buffer as
        // the last K LDS buffer (index (n0_loops - 1) % NumKVLdsBuffers); if so a barrier is
        // issued before V is stored.
        // Original note: this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4.
        // NOTE(review): the condition below depends on n0_loops, not k1_loops - confirm the note.
        if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
        {
            __builtin_amdgcn_s_barrier();
        };
        store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}],
                   tile_elementwise_in(v_element_func, v_tiles[I0]),
                   partition_index);
        __builtin_amdgcn_sched_barrier(0x00000001);
        // load the remaining prefetched V tiles (v_tiles[0], and v_tiles[1] on the two-K-tile
        // path, were already loaded during Gemm0)
        if constexpr(kPreloadWholeNextIterationK)
        {
            static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) {
                v_tiles[i_k1] = load_tile(v_dram_window);
                move_tile_window(v_dram_window, {kK1, 0});
            });
        }
        else
        {
            static_for<2, NumPrefetchV, 1>{}([&](auto i_k1) {
                v_tiles[i_k1] = load_tile(v_dram_window);
                move_tile_window(v_dram_window, {kK1, 0});
            });
        };
        __builtin_amdgcn_sched_barrier(0);
        // P = exp(S - m); rows whose max is still -inf are forced to 0
        constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
        sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            if(m[i_idx] == -numeric<CompDataType>::infinity())
            {
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
                });
            }
            else
            {
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]);
                });
            }
        });
        auto rowsum_p =
            block_tile_reduce<CompDataType>(pcomp_tile, sequence<1>{}, f_sum, CompDataType{0});
        block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
        // adjust o_acc[] according to the update between m and m_old
        constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
        sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            if(m[i_idx] == -numeric<CompDataType>::infinity())
            {
                l(i_idx) = rowsum_p[i_idx];
            }
            else
            {
                const auto tmp = f_exp(m_old[i_idx] - m[i_idx]);
                l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
                sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    o_acc(i_j_idx) *= tmp;
                });
            }
        });
        __builtin_amdgcn_sched_barrier(0x00000001);
        if constexpr(kHasDropout)
        {
            auto randval_lds_ptr =
                reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
            dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
                randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
        }
        // advance before Gemm1 so the K prefetch below can test whether another iteration follows
        seqlen_k_curr += kN0;
        __builtin_amdgcn_sched_barrier(0x00000001);
        auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
        __builtin_amdgcn_sched_barrier(0x00000001);
        // STAGE 3, Gemm_1 ( O = P@V )
        static_for<0, k1_loops, 1>{}([&](auto i_k1) {
            // keep NumPrefetchV V tiles in flight
            if constexpr(i_k1 < k1_loops - NumPrefetchV)
            {
                v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
                move_tile_window(v_dram_window, {kK1, 0});
            };
            // once no more V loads are needed, prefetch the first two K sub-tiles of the next
            // iteration (two-K-tile path only)
            if constexpr(i_k1 == k1_loops - NumPrefetchV)
            {
                if constexpr(!kPreloadWholeNextIterationK)
                {
                    if(seqlen_k_curr < seqlen_k_end)
                    {
                        k_tiles[I0] = load_tile(k_dram_window);
                        move_tile_window(k_dram_window, {kN0Sub, 0});
                    };
                }
            };
            if constexpr(i_k1 == k1_loops - NumPrefetchV + 1)
            {
                if constexpr(!kPreloadWholeNextIterationK)
                {
                    if(seqlen_k_curr < seqlen_k_end)
                    {
                        k_tiles[I1] = load_tile(k_dram_window);
                        move_tile_window(k_dram_window, {kN0Sub, 0});
                    };
                }
            };
            block_sync_lds();
            gemm_1(
                o_acc,
                get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
                v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]);
            // stage the next V tile into the following LDS buffer
            if constexpr(i_k1 < k1_loops - 1)
            {
                store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}],
                           tile_elementwise_in(v_element_func,
                                               v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]),
                           partition_index);
            };
        });
        // check whether last V-LdsBuffer overlap with first K-LdsBuffer,
        // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
        if constexpr((k1_loops - 1 + 2) % NumKVLdsBuffers == 0)
        {
            __builtin_amdgcn_s_barrier();
        };
    } while(seqlen_k_curr < seqlen_k_end);
    // store lse
    if constexpr(kStoreLSE)
    {
        auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
        constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
        sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
        });
        store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
    }
    // finally, O: normalise by the row-sum; rows that never saw a finite score become 0
    constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
    sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
        constexpr auto i_idx = make_tuple(idx0);
        sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
            constexpr auto i_j_idx = make_tuple(idx0, idx1);
            if(m[i_idx] == -numeric<CompDataType>::infinity())
                o_acc(i_j_idx) = 0.0f;
            else
                o_acc(i_j_idx) *= 1.0f / l[i_idx];
        });
    });
    o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
    return o_acc;
}
// Convenience overload: runs the full pipeline with identity functors for every element-wise
// hook. `sink_v` is accepted but not consumed by this pipeline.
template <typename QDramBlockWindowTmp,
          typename KDramBlockWindowTmp,
          typename VDramBlockWindowTmp,
          typename BiasDramBlockWindowTmp,
          typename RandValDramBlockWindowTmp,
          typename LSEDramBlockWindowTmp,
          typename PositionEncoding,
          typename AttentionVariantParams,
          typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,       // M0*K0 tile
           const KDramBlockWindowTmp& k_dram_block_window_tmp,       // N0*K0 tile
           const VDramBlockWindowTmp& v_dram_block_window_tmp,       // N1*K1 tile
           const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
           RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
           LSEDramBlockWindowTmp& lse_dram_block_window_tmp,         // M0*1 tile
           FmhaMask mask,
           PositionEncoding position_encoding,
           float scale_s,
           const AttentionVariant& variant,
           const AttentionVariantParams& variant_params,
           const BlockIndices& block_indices,
           void* smem_ptr,
           DropoutType& dropout,
           const float sink_v) const
{
    ignore = sink_v;

    // one no-op functor shared by all element-wise hooks
    const identity pass_through{};

    return (*this)(q_dram_block_window_tmp,
                   pass_through, // q_element_func
                   k_dram_block_window_tmp,
                   pass_through, // k_element_func
                   v_dram_block_window_tmp,
                   pass_through, // v_element_func
                   bias_dram_block_window_tmp,
                   pass_through, // bias_element_func
                   randval_dram_block_window_tmp,
                   lse_dram_block_window_tmp,
                   pass_through, // lse_element_func
                   pass_through, // s_acc_element_func
                   pass_through, // p_compute_element_func
                   pass_through, // o_acc_element_func
                   mask,
                   position_encoding,
                   scale_s,
                   variant,
                   variant_params,
                   block_indices,
                   smem_ptr,
                   dropout);
}
};
} // namespace ck_tile

View File

@@ -692,8 +692,11 @@ struct BlockFmhaPipelineQSKSVS
const AttentionVariantParams& variant_params,
const BlockIndices& block_indices,
void* smem_ptr,
DropoutType& dropout) const
DropoutType& dropout,
const float sink_v) const
{
ignore = sink_v;
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,

View File

@@ -57,7 +57,7 @@ struct TileFmhaShape
static constexpr index_t kQKHeaddim =
BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
// once (or repeately load Q as a whole tile)
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim must be divisible by kK0!");
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length<kQKHeaddim>();

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp"
#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
@@ -58,7 +59,9 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ =
BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD>
struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
kPadSeqLenK_,
kPadHeadDimQ_,
@@ -76,6 +79,7 @@ struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
static constexpr auto kKVLookupTable = kKVLookupTable_;
static constexpr index_t kPageBlockSize = kPageBlockSize_;
static constexpr auto kKVLoadMode = kKVLoadMode_;
static_assert(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT ||
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT,
"Batch prefill only supports vectorized or linear KV cache layout.");

View File

@@ -1685,7 +1685,7 @@ struct MoeSortingMultiPhaseKernel_P0_v1
IndexType eid = x[j.value]; // ext_vector_type must use int to []
uint32_t curr_token_id, curr_topk_id;
kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
if(eid < kargs.num_experts)
if(eid < kargs.num_experts && eid >= 0)
{
if constexpr(Problem::LocalToken)
{

View File

@@ -0,0 +1,268 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
struct BlockGemmARegBSmemCRegV2PrefetchK
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
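// C = A * B: the first K-iteration overwrites C (its previous contents are not read), later
// K-iterations accumulate on top of it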
// Block GEMM with A in registers, B read from a shared-memory window and C in registers.
// For each N iteration, the B warp tensor for K step kIter + 1 is loaded before the warp GEMMs
// of step kIter are issued.
//   c_block_tensor     : output tile; for kIter == 0 each slice is written without being read
//   a_block_tensor_tmp : A tile; its thread buffer is copied into a tensor with this struct's
//                        A distribution
//   b_block_window_tmp : window over B; only its bottom tensor view and origin are used
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
// config is a tuple of (warp GEMM, MWarp, NWarp)
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
// with NWarp == 1 (asserted below) this is always 0
const index_t iNWarp = get_warp_id<false>() % NWarp;
static_assert(NWarp == 1, "Check failed!");
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// construct A-block-tensor from A-block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
MakeABlockTileDistribution());
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
// one B warp window per (nIter, kIter); each is filled in just before it is loaded from
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
// check C-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto I0 = number<0>{};
// hot loop:
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));
statically_indexed_array<b_warp_tensor_type, KIterPerWarp> b_warp_tensors;
// read the first (kIter == 0) B warp tensor from the B block window
b_warp_windows(nIter)(I0) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(I0),
{nIter * NPerBlockPerIter, 0 * KPerBlockPerIter});
b_warp_tensors[I0] = load_tile(b_warp_windows(nIter)(I0));
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
if constexpr(kIter < KIterPerWarp - 1)
{
// prefetch the B warp tensor of the next K step before computing this one
b_warp_windows(nIter)(number<kIter + 1>{}) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(number<kIter + 1>{}),
{nIter * NPerBlockPerIter, (kIter + 1) * KPerBlockPerIter});
b_warp_tensors[number<kIter + 1>{}] =
load_tile(b_warp_windows(nIter)(number<kIter + 1>{}));
};
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
CWarpTensor c_warp_tensor;
if constexpr(kIter == 0)
{
// first K step: take the warp GEMM result directly, the C block tensor is not read
// (the two-argument WG call is presumed to return A*B from a fresh accumulator)
c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[kIter]);
}
else
{
// later K steps: read C warp tensor from C block tensor and accumulate into it
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]);
};
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
// Builds the block-level tile distribution encoding for the A operand by embedding the warp
// GEMM's A encoding into an outer (M-iteration, M-warp, K-iteration) encoding.
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
    constexpr auto warp_cfg = Policy::template GetWarpGemmMWarpNWarp<Problem>();
    using WarpGemm          = remove_cvref_t<decltype(warp_cfg.template at<0>())>;

    constexpr index_t m_warps = warp_cfg.template at<1>();
    constexpr index_t n_warps = warp_cfg.template at<2>();
    constexpr index_t m_iters = MPerBlock / (m_warps * WarpGemm::kM);
    constexpr index_t k_iters = KPerBlock / WarpGemm::kK;

    constexpr auto outer_encoding =
        tile_distribution_encoding<sequence<n_warps>,
                                   tuple<sequence<m_iters, m_warps>, sequence<k_iters>>,
                                   tuple<sequence<1, 0>>,
                                   tuple<sequence<1, 0>>,
                                   sequence<1, 2>,
                                   sequence<0, 0>>{};

    return detail::make_embed_tile_distribution_encoding(
        outer_encoding, typename WarpGemm::AWarpDstrEncoding{});
}
// Returns the static tile distribution for the A block tile, i.e. the encoding
// from MakeABlockDistributionEncode<MPerBlock, KPerBlock>() turned into a
// distribution object via make_static_tile_distribution.
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
{
    return make_static_tile_distribution(MakeABlockDistributionEncode<MPerBlock, KPerBlock>());
}
// Builds the compile-time distribution encoding for the C block tile.
//
// The warp GEMM type and the MWarp/NWarp counts come from
// Policy::GetWarpGemmMWarpNWarp<Problem>(). Compilation fails unless NWarp == 1.
// The outer encoding is embedded with the warp GEMM's CWarpDstrEncoding.
//
// MPerBlock / NPerBlock: extents of the C block tile along M and N; they default
// to the BlockGemmShape extents.
// Returns the embedded (outer + warp-level) encoding object.
template <index_t MPerBlock = BlockGemmShape::kM, index_t NPerBlock = BlockGemmShape::kN>
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
{
    constexpr auto warp_cfg = Policy::template GetWarpGemmMWarpNWarp<Problem>();
    using WarpGemm          = remove_cvref_t<decltype(warp_cfg.template at<0>())>;

    constexpr index_t kNumMWarps = warp_cfg.template at<1>();
    constexpr index_t kNumNWarps = warp_cfg.template at<2>();
    static_assert(kNumNWarps == 1, "Check failed!");

    // Number of warp-GEMM steps each warp performs along M and along N.
    constexpr index_t kMSteps = MPerBlock / (kNumMWarps * WarpGemm::kM);
    constexpr index_t kNSteps = NPerBlock / (kNumNWarps * WarpGemm::kN);

    using OuterEncoding =
        tile_distribution_encoding<sequence<kNumNWarps>,
                                   tuple<sequence<kMSteps, kNumMWarps>, sequence<kNSteps>>,
                                   tuple<sequence<1, 0>>,
                                   tuple<sequence<1, 0>>,
                                   sequence<1, 2>,
                                   sequence<0, 0>>;

    return detail::make_embed_tile_distribution_encoding(
        OuterEncoding{}, typename WarpGemm::CWarpDstrEncoding{});
}
// Creates the C block tensor: a static distributed tensor of CDataType laid out
// with the distribution derived from MakeCBlockDistributionEncode<MPerBlock, NPerBlock>().
// No element values are written here.
template <index_t MPerBlock = BlockGemmShape::kM, index_t NPerBlock = BlockGemmShape::kN>
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
    constexpr auto c_tile_dstr =
        make_static_tile_distribution(MakeCBlockDistributionEncode<MPerBlock, NPerBlock>());
    return make_static_distributed_tensor<CDataType>(c_tile_dstr);
}
// C = A * B
// Convenience overload: allocates a C block tensor with MakeCBlockTile(), hands it
// to the three-argument overload together with A and the B window, and returns it.
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
                               const BBlockWindowTmp& b_block_window_tmp) const
{
    auto c_result = MakeCBlockTile();
    (*this)(c_result, a_block_tensor_tmp, b_block_window_tmp);
    return c_result;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,239 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
namespace ck_tile {
// Block-level GEMM built from warp GEMMs, with the B warp tile of the next N
// iteration loaded before the warp GEMMs of the current N iteration are issued.
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
//
// Problem_ supplies ADataType / BDataType / CDataType, BlockGemmShape and kBlockSize.
// Policy_ supplies GetWarpGemmMWarpNWarp<Problem>(), from which the warp GEMM type (WG)
// and the MWarp / NWarp counts are read.
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
struct BlockGemmARegBSmemCRegV2PrefetchN
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// C += A * B
// c_block_tensor     : accumulator; every (mIter, nIter) slice is read, updated by the
//                      warp GEMM and written back.
// a_block_tensor_tmp : A operand; its thread buffer is copied into a tensor that uses
//                      MakeABlockTileDistribution().
// b_block_window_tmp : window over B; dimension 0 is taken as N here (see NPerBlock) and
//                      the warp window below is (WG::kN, WG::kK).
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
// element types of the arguments must match the Problem's A/B/C data types
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
// M and K come from the A tensor lengths, N from dimension 0 of the B window
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
// tile extents of the arguments must equal the BlockGemmShape extents
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
// number of warp GEMM steps each warp performs along M / N / K
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
// window step (in elements) between consecutive N / K iterations
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
// this warp's index along N; used to offset the B warp window origin
const index_t iNWarp = get_warp_id<false>() % NWarp;
// expected C distribution; only used by the static_assert further down
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// construct A-block-tensor from A-block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
MakeABlockTileDistribution());
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// construct B-warp-window
// (WG::kN x WG::kK) window on the same bottom tensor view, shifted along N by this
// warp's offset
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
// one window per (nIter, kIter); indexed as b_warp_windows(nIter)(kIter)
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
// check C-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
// per-warp Y lengths and all-zero Y origins, appended to the (iter, iter) slice
// coordinates when slicing the block tensors below
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto I0 = number<0>{};
using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));
// hot loop: K outermost, then N, then M
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// B warp tensors for every nIter of the current kIter
statically_indexed_array<b_warp_tensor_type, NIterPerWarp> b_warp_tensors;
// read B warp tensor from B Block window
// (prologue: load the nIter == 0 tile before entering the N loop)
b_warp_windows(I0)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(I0)(kIter),
{0 * NPerBlockPerIter, kIter * KPerBlockPerIter});
b_warp_tensors(I0) = load_tile(b_warp_windows(I0)(kIter));
// scheduling barrier with mask 0x1 (per the LLVM AMDGPU docs this mask lets only
// non-memory, non-side-effect instructions be scheduled across this point)
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter < NIterPerWarp - 1)
{
// read B warp tensor from B Block window
// (prefetch: the tile for nIter + 1 is loaded before the GEMMs of nIter)
b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
{(nIter + 1) * NPerBlockPerIter, kIter * KPerBlockPerIter});
b_warp_tensors(number<nIter + 1>{}) =
load_tile(b_warp_windows(number<nIter + 1>{})(kIter));
};
// same mask as above; presumably keeps the prefetch load ahead of the warp GEMMs
// that follow - TODO confirm intent
__builtin_amdgcn_sched_barrier(0x00000001);
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read C warp tensor from C block tensor
// (always read, including kIter == 0, so the caller's C values are accumulated into)
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
// Returns the static tile distribution used for the A block tensor.
// MPerBlock / KPerBlock default to the BlockGemmShape extents. NWarp appears only in the
// leading sequence of the outer encoding, not in the M/K tuple.
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
{
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
return make_static_tile_distribution(a_block_dstr_encode);
}
// Creates a C block tensor (CDataType) for a BlockGemmShape::kM x BlockGemmShape::kN tile.
// The outer encoding is the same one the accumulating operator() checks its C argument
// against. No element values are written here.
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C = A * B
// Allocates C with MakeCBlockTile(), runs the accumulating overload on it and returns it.
// NOTE(review): the accumulating overload reads c_block_tensor before writing it and no
// explicit clear happens here, so the result equals A * B only if MakeCBlockTile() yields
// a zero-filled tensor - confirm.
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
return c_block_tensor;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,243 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
namespace ck_tile {
// Block-level GEMM built from warp GEMMs, with the B warp tile of the next N
// iteration loaded before the warp GEMMs of the current N iteration are issued.
// B warp tiles are read with load_tile_transpose from a window whose dimension 0 is K
// and dimension 1 is N.
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
//
// Problem_ supplies ADataType / BDataType / CDataType, BlockGemmShape and kBlockSize.
// Policy_ supplies GetWarpGemmMWarpNWarp<Problem>(), from which the warp GEMM type (WG)
// and the MWarp / NWarp counts are read.
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
struct BlockGemmARegBSmemTrLoadCRegV2PrefetchN
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// C += A * B
// c_block_tensor     : accumulator; every (mIter, nIter) slice is read, updated by the
//                      warp GEMM and written back.
// a_block_tensor_tmp : A operand; its thread buffer is copied into a tensor that uses
//                      MakeABlockTileDistribution().
// b_block_window_tmp : window over B; dimension 1 is taken as N here (see NPerBlock) and
//                      the warp window below is (WG::kK, WG::kN).
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
// element types of the arguments must match the Problem's A/B/C data types
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
// M and K come from the A tensor lengths, N from dimension 1 of the B window
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<1>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
// tile extents of the arguments must equal the BlockGemmShape extents
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
// number of warp GEMM steps each warp performs along M / N / K
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
// window step (in elements) between consecutive N / K iterations
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
// this warp's index along N; used to offset the B warp window origin
const index_t iNWarp = get_warp_id<false>() % NWarp;
// expected C distribution; only used by the static_assert further down
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// construct A-block-tensor from A-block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
MakeABlockTileDistribution());
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// distribution encoding for the B warp window: the transposed form of the warp GEMM's
// BWarpDstrEncoding, as provided by InputTileDistributionTraits
constexpr auto b_warp_dstr_encode =
typename InputTileDistributionTraits<typename WG::BWarpDstrEncoding,
BDataType>::TransposedDstrEncode{};
// construct B-warp-window
// (WG::kK x WG::kN) window on the same bottom tensor view, shifted along N (dimension 1)
// by this warp's offset
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kK>{}, number<WG::kN>{}),
b_block_window_tmp.get_window_origin() + multi_index<2>{0, iNWarp * WG::kN},
make_static_tile_distribution(b_warp_dstr_encode));
// one window per (nIter, kIter); indexed as b_warp_windows(nIter)(kIter)
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
// check C-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
// per-warp Y lengths and all-zero Y origins, appended to the (iter, iter) slice
// coordinates when slicing the block tensors below
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto I0 = number<0>{};
using b_warp_tensor_type = decltype(load_tile_transpose(b_warp_windows(I0)(I0)));
// hot loop: K outermost, then N, then M
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// B warp tensors for every nIter of the current kIter
statically_indexed_array<b_warp_tensor_type, NIterPerWarp> b_warp_tensors;
// read B warp tensor from B Block window
// (prologue: load the nIter == 0 tile before entering the N loop; window offsets are
// given as {k, n})
b_warp_windows(I0)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(I0)(kIter),
{kIter * KPerBlockPerIter, 0 * NPerBlockPerIter});
b_warp_tensors(I0) = load_tile_transpose(b_warp_windows(I0)(kIter));
// scheduling barrier with mask 0 (per the LLVM AMDGPU docs no instruction may be
// scheduled across this point)
__builtin_amdgcn_sched_barrier(0);
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter < NIterPerWarp - 1)
{
// read B warp tensor from B Block window
// (prefetch: the tile for nIter + 1 is loaded before the GEMMs of nIter)
b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
{kIter * KPerBlockPerIter, (nIter + 1) * NPerBlockPerIter});
b_warp_tensors(number<nIter + 1>{}) =
load_tile_transpose(b_warp_windows(number<nIter + 1>{})(kIter));
};
// same full barrier; presumably keeps the prefetch load ahead of the warp GEMMs
// that follow - TODO confirm intent
__builtin_amdgcn_sched_barrier(0);
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read C warp tensor from C block tensor
// (always read, including kIter == 0, so the caller's C values are accumulated into)
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
// Returns the static tile distribution used for the A block tensor.
// MPerBlock / KPerBlock default to the BlockGemmShape extents. NWarp appears only in the
// leading sequence of the outer encoding, not in the M/K tuple.
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
{
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
return make_static_tile_distribution(a_block_dstr_encode);
}
// Creates a C block tensor (CDataType) for a BlockGemmShape::kM x BlockGemmShape::kN tile.
// The outer encoding is the same one the accumulating operator() checks its C argument
// against. No element values are written here.
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C = A * B
// Allocates C with MakeCBlockTile(), runs the accumulating overload on it and returns it.
// NOTE(review): the accumulating overload reads c_block_tensor before writing it and no
// explicit clear happens here, so the result equals A * B only if MakeCBlockTile() yields
// a zero-filled tensor - confirm.
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
return c_block_tensor;
}
};
} // namespace ck_tile