intial commit

This commit is contained in:
Tianxing Wu
2025-10-06 13:02:38 +00:00
parent ef43078788
commit e54cb5a713
12 changed files with 4719 additions and 0 deletions

View File

@@ -0,0 +1,37 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockAttentionBiasEnum
{
NO_BIAS = 0,
ELEMENTWISE_BIAS = 1, // attention bias, each elements add to the result of Q*K(after scale)
ALIBI = 2, // bias computed with position encoding, applied after scale
};
template <BlockAttentionBiasEnum>
struct BlockAttentionBiasEnumToStr;
template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::NO_BIAS>
{
static constexpr const char* name = "";
};
template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ELEMENTWISE_BIAS>
{
static constexpr const char* name = "bias";
};
template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ALIBI>
{
static constexpr const char* name = "alibi";
};
} // namespace ck_tile

View File

@@ -0,0 +1,654 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile {
// BlockDropoutBwd and BlockDropout (fwd) support two warp gemm tile sizes: 32x32 (MFMA only) and
// 16x16 (MFMA and WMMA). Even if fwd and bwd use different tile sizes, generated random
// numbers will be the same, they are also the same for MFMA (on CDNA), WMMA (on RDNA), or host
// (for verification, see ck_tile/host/reference/reference_batched_dropout_randval.hpp).
//
// The (row, col) coordinate of the current 32x32 tile in the P matrix determines a subsequence of
// random numbers (ph_subsequence).
// The (batch, head, 0..63) coordinate determines an offset in the subsequence (ph_head_offset and
// ph_offset).
// This means that subsequences are non-overlapping, reproducible and independent of mask or window.
//
// There are 3 modes (all produce the same results):
// * For 32x32 MFMA tile each of 64 lanes generates 4 * 32 bits or 16 bytes, so one warp generates
// the entire 32x32 tile (64 * 16 = 32 * 32).
// * For 16x16 MFMA tile one warp generates 1/4 of the 32x32 tile ((16 * 16) / (64 * 16) = 1/4), 4
// warps generate the same 64 * 16 random bytes and each uses its own quarter. If kMPerBlock >
// MWarp * WG::kM one warp can generate two 16x16 tiles (MIterPerWarp = 2) so fewer instructions
// are needed for generating a 32x32 tile.
// * For 16x16 WMMA tile one warp generates 1/2 of the 32x32 tile ((16 * 16) / (32 * 16) = 1/2), 2
// warps generate the same 64 * 16 random bytes and each uses its own half. If kMPerBlock > MWarp *
// WG::kM one warp can generate two 16x16 tiles.
namespace detail {
// The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values
constexpr index_t philox_per_tile = 64;
} // namespace detail
struct NullBlockDropout
{
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start)
{
(void)randval_dram_block_window_tmp;
(void)seqlen_qk_start;
return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
}
};
struct BlockDropout
{
CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch,
index_t i_head,
index_t nheads,
unsigned long long seed,
unsigned long long offset,
float rp_undrop_,
uint8_t p_undrop_in_uint8_t_,
bool is_store_randval_)
: ph_seed(amd_wave_read_first_lane(seed)),
ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) *
detail::philox_per_tile)),
rp_undrop(rp_undrop_),
p_undrop_in_uint8_t(p_undrop_in_uint8_t_),
is_store_randval(is_store_randval_)
{
}
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start)
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
auto randval_dram_window = [&]() {
if constexpr(IsFwd)
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
}
else
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
}
}();
return randval_dram_window;
}
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
constexpr index_t kN1 = 8;
constexpr index_t kN0 = kNPerStep / kN1;
constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
ck_tile::make_tuple(number<kN0>{}, number<kMPerStep>{}, number<kN1>{}),
ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
number<kN1>{},
number<1>{});
constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
randval_lds_block_desc_0,
ck_tile::make_tuple(
make_pass_through_transform(number<kMPerStep>{}),
make_merge_transform(ck_tile::make_tuple(number<kN0>{}, number<kN1>{}))),
ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}),
ck_tile::make_tuple(sequence<0>{}, sequence<1>{}));
return randval_lds_block_desc;
}
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t NIterPerWarp = 1;
// The tile distribution is different from the one in MakeRandValLdsShuffleTileDistribution,
// because it can combine 2 (MIterPerWarp) 16x16 subtiles for generating them at once
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MWarp, MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<1, 0>>{};
// Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
constexpr auto randval_block_inner_part_dstr_encoding =
typename WarpGemmDispatcher<typename WG::ADataType,
typename WG::BDataType,
typename WG::CDataType,
WG::kM,
WG::kN,
WG::kK,
false,
IsWG32>::CWarpDstrEncoding{};
constexpr auto randval_block_part_dstr_encode =
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
randval_block_inner_part_dstr_encoding);
return make_static_tile_distribution(randval_block_part_dstr_encode);
}
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t NIterPerWarp = 1;
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto randval_block_part_dstr_encode =
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
typename WG::CWarpDstrEncoding{});
return make_static_tile_distribution(randval_block_part_dstr_encode);
}
template <typename BlockGemm,
typename PComputeDataType,
typename RandValOutputDataType,
typename PComputeWindow,
typename RandValDramWindow>
CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
const index_t start_n0_idx,
PComputeWindow& p_compute,
RandValDramWindow& randval_dram_window) const
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kNPerBlock = BlockGemmShape::kN;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
// randval tile in LDS
auto randval_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
auto randval_lds_window = make_tile_window(
randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
// register distribute
auto randval_dist_generated =
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
const auto randval_lds_read_window =
make_tile_window(randval_lds_window.get_bottom_tensor_view(),
randval_lds_window.get_window_lengths(),
randval_lds_window.get_window_origin(),
MakeRandValLdsShuffleTileDistribution<BlockGemm>());
const index_t start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
auto generate_randval = [&](auto i_m0, auto i_n0) {
// Generate random numbers
uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
if constexpr(IsWG32)
{
// Generate the whole 32x32 tile at once (each tile consists of random numbers taken
// from a separate subsequence of Philox)
const unsigned long long ph_subsequence =
bit_cast<unsigned long long>(make_uint2(wg_m0, wg_n0));
const index_t ph_offset = get_lane_id();
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
ph.get_random_16x8(random_uint8_t, ph_subsequence);
}
else
{
// Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether
// MIterPerWarp is equal to 1 or 2)
const unsigned long long ph_subsequence =
bit_cast<unsigned long long>(make_uint2(wg_m0 / 2, wg_n0 / 2));
const index_t subtile_m0 = wg_m0 % 2;
if constexpr(get_warp_size() == 32)
{
const index_t ph_offset = (get_lane_id() & 15) +
(((get_lane_id() >> 4) & 1) << 5) +
((wg_n0 % 2) << 4);
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
if constexpr(MIterPerWarp == 1)
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
ph.get_random_8x8(
random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
}
else
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
ph.get_random_16x8(random_uint8_t, ph_subsequence);
}
}
else
{
const index_t subtile_n0 = (get_lane_id() >> 4) & 1;
const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4);
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
if constexpr(MIterPerWarp == 1)
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
ph.get_random_4x8(
random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
}
else
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
ph.get_random_8x8(
random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
}
}
}
constexpr auto randval_dist_generated_spans =
decltype(randval_dist_generated)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
});
});
// Transpose randval using LDS
store_tile(randval_lds_window, randval_dist_generated);
block_sync_lds();
const auto randval = load_tile(randval_lds_read_window);
block_sync_lds();
return randval;
};
if(is_store_randval)
{
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
const auto randval = generate_randval(i_m0, i_n0);
// save to Global
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {0, kNPerStep});
});
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
});
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
}
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
const auto randval = generate_randval(i_m0, i_n0);
// Drop values of P based on the generated probabilities
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto p_idx0 =
tile_distributed_index<i_m0 * MIterPerWarp +
idx0.impl_.template at<0>()>{};
constexpr auto p_idx1 =
tile_distributed_index<i_n0,
idx1.impl_.template at<1>(),
idx1.impl_.template at<2>()>{};
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
? p_compute[p_idx] * rp_undrop
: PComputeDataType(0);
});
});
});
});
}
const unsigned long long ph_seed;
const unsigned long long ph_head_offset;
const float rp_undrop;
const uint8_t p_undrop_in_uint8_t;
const bool is_store_randval;
};
// TODO: IsWG32_ is not needed as template parameter and can be removed. IsDropout_ == false can be
// replaced with NullBlockDropout. This requires changes in xformers and other libs.
template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd;
template <bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd<false, IsWG32_, IsStoreRandval_>
{
static constexpr bool IsDropout = false;
static constexpr bool IsStoreRandval = IsStoreRandval_;
template <typename BlockGemm, bool IsFwd = false, typename RandValDramBlockWindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start)
{
(void)randval_dram_block_window_tmp;
(void)seqlen_qk_start;
return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
}
};
template <bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
{
static constexpr bool IsDropout = true;
static constexpr bool IsStoreRandval = IsStoreRandval_;
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch,
index_t i_head,
index_t nheads,
unsigned long long seed,
unsigned long long offset,
float rp_undrop_,
uint8_t p_undrop_in_uint8_t_)
: ph_seed(amd_wave_read_first_lane(seed)),
ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) *
detail::philox_per_tile)),
rp_undrop(rp_undrop_),
p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
{
}
template <typename BlockGemm, bool IsFwd = false, typename RandValDramBlockWindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start)
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
auto randval_dram_window = [&]() {
if constexpr(IsFwd)
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
}
else
{
return make_tile_window(
randval_dram_block_window_tmp.get_bottom_tensor_view(),
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
}
}();
return randval_dram_window;
}
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t NIterPerWarp = 1;
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MWarp, MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<1, 0>>{};
constexpr auto randval_block_inner_part_dstr_encoding =
typename WarpGemmDispatcher<typename WG::ADataType,
typename WG::BDataType,
typename WG::CDataType,
WG::kM,
WG::kN,
WG::kK,
false,
IsWG32>::CWarpDstrEncoding{};
static_assert(
std::is_same_v<remove_cvref_t<decltype(randval_block_inner_part_dstr_encoding)>,
typename WG::CWarpDstrEncoding>);
constexpr auto randval_block_part_dstr_encode =
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
randval_block_inner_part_dstr_encoding);
return make_static_tile_distribution(randval_block_part_dstr_encode);
}
template <typename BlockGemm,
typename RandValOutputDataType,
typename PComputeWindow,
typename RandValDramWindow>
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
const index_t start_n0_idx,
PComputeWindow& p_compute,
RandValDramWindow& randval_dram_window) const
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kNPerBlock = BlockGemmShape::kN;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
// register distribute
auto randval_dist_generated =
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
auto generate_randval = [&](auto i_m0, auto i_n0) {
// Generate random numbers
uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
if constexpr(IsWG32)
{
// Generate the whole 32x32 tile at once (each tile consists of random numbers
// taken from a separate subsequence of Philox)
const unsigned long long ph_subsequence =
bit_cast<unsigned long long>(make_uint2(wg_m0, wg_n0));
const index_t ph_offset = get_lane_id();
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
ph.get_random_16x8(random_uint8_t, ph_subsequence);
}
else
{
// Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether
// MIterPerWarp is equal to 1 or 2)
const unsigned long long ph_subsequence =
bit_cast<unsigned long long>(make_uint2(wg_m0 / 2, wg_n0 / 2));
const index_t subtile_m0 = wg_m0 % 2;
if constexpr(get_warp_size() == 32)
{
const index_t ph_offset = (get_lane_id() & 15) +
(((get_lane_id() >> 4) & 1) << 5) +
((wg_n0 % 2) << 4);
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
if constexpr(MIterPerWarp == 1)
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
ph.get_random_8x8(
random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
}
else
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
ph.get_random_16x8(random_uint8_t, ph_subsequence);
}
}
else
{
const index_t subtile_n0 = (get_lane_id() >> 4) & 1;
const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4);
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
if constexpr(MIterPerWarp == 1)
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
ph.get_random_4x8(
random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
}
else
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
ph.get_random_8x8(
random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
}
}
}
constexpr auto randval_dist_generated_spans =
decltype(randval_dist_generated)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
});
});
return randval_dist_generated;
};
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
const auto randval = generate_randval(i_m0, i_n0);
// Drop values of P based on the generated probabilities, negative sign is used to
// distinguish such values later in bwd pipeline.
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
constexpr auto p_idx0 =
tile_distributed_index<i_m0 * MIterPerWarp +
idx0.impl_.template at<0>(),
idx0.impl_.template at<1>(),
idx0.impl_.template at<2>()>{};
constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
? p_compute[p_idx]
: -p_compute[p_idx];
});
});
// save to Global
if constexpr(IsStoreRandval)
{
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {kMPerStep, 0});
}
});
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
}
});
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
}
}
const unsigned long long ph_seed;
const unsigned long long ph_head_offset;
const float rp_undrop;
const uint8_t p_undrop_in_uint8_t;
};
} // namespace ck_tile

View File

@@ -0,0 +1,642 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
enum struct GenericAttentionMaskEnum
{
NO_MASK = 0,
// below enum could be causal, or sliding window
MASK_FROM_TOP_LEFT = 1,
MASK_FROM_BOTTOM_RIGHT = 2,
// this enum maybe not used by xformer/FA, since it's hard to
// specify left/right window for varlen case. put it here for
// debug purpose
MASK_GENERIC,
};
// clang-format off
/* generic Attention Mask Coordinate
use x(horizontal axis), y(vertical axis) to describe mask.
top-left corner is origin
x=1/y=5(top-left) x=4/y=5(botm-r) x=6/y=5 x=8/y=5(no mask)
1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1
1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1
1 1 1 * * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 * * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 * * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
l=7,-1/r=0(tl) l=7,-1/r=0(br)
x=1/y=2 x=4/y=2 x=6/y=2 x=8/y=2
1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1
1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1
* 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1
* * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1
* * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 * * * 1 1 1 1 1
l=1/r=0(tl) l=1/r=3(tl) l=1/r=5(tl) l=1/r=7(tl)
l=4/r=0(br) l=4/r=2(br) l=4/r=4(br)
x=4/y=-1 x=6/y=-1 x=8/y=-1
* * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1
* * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1
* * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1
* * * * * 1 1 * * * * * * 1 1 1 * * * * * 1 1 1
* * * * * * 1 1 * * * * * * 1 1 * * * * * * 1 1
x=-2/y=5 x=1/y=5(top-left) x=0/y=5(botm-r)
* * * * * * * * 1 * * * * * * *
* * * * * * * * 1 1 * * 1 * * *
* * * * * * * * 1 1 1 * 1 1 * *
1 * * * * * * * 1 1 1 1 1 1 1 *
1 1 * * * * * * 1 1 1 1 1 1 1 1
Validations:
x + y > 1 (x + y >= 2)
Note:
y = seq_q, x = 1 -> top-left
y = seq_q, x = seq_k - seq_q + 1 -> bottom-right
y < seq_q, x < seq_k -> local-attn
y = seq_q, x = seq_k -> no mask
*/
namespace impl {
template <bool IsMasking_, bool IsLocal_> struct MaskName;
template<> struct MaskName<false, false> { static constexpr const char * name = "mn"; };
template<> struct MaskName<false, true> { static constexpr const char * name = "mn"; };
template<> struct MaskName<true, false> { static constexpr const char * name = "mc"; };
template<> struct MaskName<true, true> { static constexpr const char * name = "mg"; };
}
// clang-format on
template <bool IsMasking_ = true, bool IsLocal_ = false>
struct GenericAttentionMask
{
static constexpr bool IsMasking = IsMasking_; // false will disable masking
static constexpr bool IsLocal = IsLocal_; // if true, upper/lower area could have mask,
// else only upper-right could have mask
static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_)
: GenericAttentionMask(0, 0, y_total_, x_total_)
{
}
CK_TILE_HOST_DEVICE
GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
{
}
template <typename MaskCoordinates>
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord)
: y(mask_coord.at(number<0>{})),
x(mask_coord.at(number<1>{})),
y_total(mask_coord.at(number<2>{})),
x_total(mask_coord.at(number<3>{}))
{
}
// to get the loop length along X axis, return index:[start, end), end-start=length
// use this if need loop over X axis tile by tile (like k-seqlen loopover)
// TODO: x_end still could be negative, so end-start could be negative(need check)
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, x_total);
}
else
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
if constexpr(IsLocal)
{
index_t tmp = max(-y + i_y + 1, 0);
return (tmp / XTile) * XTile; // round to tile aligned
}
else
{
return 0;
}
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min(i_y + YTile - 1 + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
return ck_tile::make_tuple(x_start, x_end);
}
}
// to get the loop length along Y axis, return index:[start, end), end-start=length
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
// TODO: y_end still could be negative, so end-start could be negative(need check)
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, y_total);
}
else
{
// get the tile start/end range assum we loop over along Y tile by tile
index_t y_start = [&]() {
index_t tmp = max(-x + i_x + 1, 0);
return (tmp / YTile) * YTile; // round to tile aligned
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t y_end = [&]() {
index_t tmp = min(i_x + XTile - 1 + y, y_total);
return ((tmp + YTile - 1) / YTile) * YTile;
}();
return ck_tile::make_tuple(y_start, y_end);
}
}
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
{
if constexpr(!IsMasking)
{
return i_x >= x_total;
}
else
{
// no need to do min/max here, since i_x will never be < 0 or >= x_total
index_t x_start = -y + i_y + 1;
index_t x_end = min(i_y + x, x_total);
if constexpr(IsLocal)
{
return i_x < x_start || i_x >= x_end;
}
else
{
return i_x >= x_end || i_y >= y_total;
}
}
}
// if current tile is at the edge, means need per-pixel mask check.
// otherwise no need to check per-pixel
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
// can be used as a fast-path to decide if do per-pixel check or not
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto
IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number<TileHeight>, number<TileWidth>) const
{
if constexpr(!IsMasking)
{
// TODO: no need to check begin
return (i_tile_left + TileWidth) > x_total;
}
else
{
if constexpr(IsLocal)
{
// check top-right corner > x or left-borrom corner < x
index_t i_tile_right = i_tile_left + TileWidth;
index_t i_tile_bottom = i_tile_top + TileHeight;
index_t x_end = min(i_tile_top + x, x_total);
bool top_right_edge = i_tile_right > (i_tile_top + x);
bool bottom_left_edge = i_tile_bottom > (i_tile_left + y);
bool is_partial_out_of_bound =
i_tile_right > x_end; // only consider right-pad for now
return top_right_edge || bottom_left_edge || is_partial_out_of_bound;
}
else
{
// only need to check top-right corner > x
index_t i_tile_right = i_tile_left + TileWidth;
index_t x_end = min(i_tile_top + x, x_total);
bool top_right_edge = i_tile_right > x_end;
return top_right_edge;
}
}
}
private:
index_t y, x;
index_t y_total, x_total;
};
// clang-format off
namespace impl {
template <bool IsMasking_> struct SimplifiedMaskName;
template<> struct SimplifiedMaskName<false> { static constexpr const char * name = "nomask"; };
template<> struct SimplifiedMaskName<true> { static constexpr const char * name = "mask"; };
}
// clang-format on
// this version only have 2 variation: masking and non-masking
// This is more friendly to codegen (e.g. need generate less kernel)
// ... with the trade-off that may have more instruction in causal mode
template <bool IsMasking_ = true>
struct SimplifiedGenericAttentionMask
{
static constexpr bool IsMasking = IsMasking_; // false will disable masking
static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
: SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
{
}
CK_TILE_HOST_DEVICE
SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
{
}
template <typename MaskCoordinates>
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
: y(mask_coord.at(number<0>{})),
x(mask_coord.at(number<1>{})),
y_total(mask_coord.at(number<2>{})),
x_total(mask_coord.at(number<3>{}))
{
}
// to get the loop length along X axis, return index:[start, end), end-start=length
// use this if need loop over X axis tile by tile (like k-seqlen loopover)
// TODO: x_end still could be negative, so end-start could be negative(need check)
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, x_total);
}
else
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
index_t tmp = max(-y + i_y + 1, 0);
return (tmp / XTile) * XTile; // round to tile aligned
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min(i_y + YTile - 1 + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
return ck_tile::make_tuple(x_start, x_end);
}
}
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y,
number<TileHeight> height,
number<TileWidth> width,
index_t num_splits,
index_t i_split) const
{
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
const index_t split_start = x_per_split * i_split;
const index_t split_end = ck_tile::min(x_total, split_start + x_per_split);
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
ck_tile::min(origin_end, split_end));
}
// to get the loop length along Y axis, return index:[start, end), end-start=length
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
// TODO: y_end still could be negative, so end-start could be negative(need check)
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, y_total);
}
else
{
// get the tile start/end range assum we loop over along Y tile by tile
index_t y_start = [&]() {
index_t tmp = max(-x + i_x + 1, 0);
return (tmp / YTile) * YTile; // round to tile aligned
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t y_end = [&]() {
index_t tmp = min(i_x + XTile - 1 + y, y_total);
return ((tmp + YTile - 1) / YTile) * YTile;
}();
return ck_tile::make_tuple(y_start, y_end);
}
}
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
{
if constexpr(!IsMasking)
{
// the only case that need do following compare is under kPadSeqLenK
// ... for non-masking kernel.
return i_x >= x_total;
}
else
{
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
return i_x < x_start || i_x >= x_end || i_y >= y_total;
}
}
// if current tile is at the edge, means need per-pixel mask check.
// otherwise no need to check per-pixel
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
// can be used as a fast-path to decide if do per-pixel check or not
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto
IsEdgeTile(index_t i_y, index_t i_x, number<TileHeight>, number<TileWidth>) const
{
if constexpr(!IsMasking)
{
// the only case that need do following compare is under kPadSeqLenK
// ... for non-masking kernel.
// return (i_x < x_total) && ((i_x + TileWidth) > x_total);
// TODO: no need to check begin
return (i_x + TileWidth) > x_total;
}
else
{
// check top-right corner > x or left-borrom corner < x
index_t i_x_end = i_x + TileWidth;
index_t i_y_end = i_y + TileHeight;
// index_t x_end = min(i_y + x, x_total);
bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad
bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad
// bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now
return top_right_edge || bottom_left_edge;
}
}
private:
index_t y, x;
index_t y_total, x_total;
};
// clang-format off
namespace impl {
template <bool IsMasking_> struct SimplifiedRatioMaskName;
template<> struct SimplifiedRatioMaskName<false> { static constexpr const char * name = "nomask"; };
template<> struct SimplifiedRatioMaskName<true> { static constexpr const char * name = "mask"; };
}
// clang-format on
// this version is used for cases that the step length of y-direction changes greater than one. It
// means that the mask is not a regular triangular matrix.
// clang-format off
/* y_ratio is used to describe the step length of y-direction changes
in certain performance optimization scenarios like merging seqlen
and qk_head_ratio, for example:
x=1/y=6/y_ratio=2(top-left)
1 * * * * * * *
1 * * * * * * *
1 1 * * * * * *
1 1 * * * * * *
1 1 1 * * * * *
1 1 1 * * * * *
*/
// clang-format on
template <bool IsMasking_ = true>
struct SimplifiedRatioAttentionMask
{
static constexpr bool IsMasking = IsMasking_; // false will disable masking
static constexpr const char* name = impl::SimplifiedRatioMaskName<IsMasking>::name;
CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_total_, index_t x_total_)
: SimplifiedRatioAttentionMask(0, 0, y_total_, x_total_, 0, 1, mdiv{})
{
}
CK_TILE_HOST_DEVICE
SimplifiedRatioAttentionMask(
index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_)
: SimplifiedRatioAttentionMask(/*y_=*/y_real_ * static_cast<index_t>(y_ratio_mdiv_.get()),
/*x_=*/x_,
/*y_total_=*/y_total_,
/*x_total_=*/x_total_,
/*y_real_=*/y_real_,
/*y_ratio_=*/static_cast<index_t>(y_ratio_mdiv_.get()),
/*y_ratio_mdiv_=*/y_ratio_mdiv_)
{
}
CK_TILE_HOST_DEVICE
SimplifiedRatioAttentionMask(index_t y_,
index_t x_,
index_t y_total_,
index_t x_total_,
index_t y_real_,
index_t y_ratio_,
mdiv y_ratio_mdiv_)
: y(y_),
x(x_),
y_total(y_total_),
x_total(x_total_),
y_real(y_real_),
y_ratio(y_ratio_),
y_ratio_mdiv(y_ratio_mdiv_)
{
}
// to get the loop length along X axis, return index:[start, end), end-start=length
// use this if need loop over X axis tile by tile (like k-seqlen loopover)
// TODO: x_end still could be negative, so end-start could be negative(need check)
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, x_total);
}
else
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
index_t tmp = -y_real +
static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(i_y))) +
1;
return (tmp / XTile) * XTile; // round to tile aligned
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
uint32_t y_offset = i_y + YTile - 1;
index_t tmp = min(static_cast<index_t>(y_ratio_mdiv.div(y_offset)) + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
return ck_tile::make_tuple(x_start, x_end);
}
}
// to get the loop length along Y axis, return index:[start, end), end-start=length
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
// TODO: y_end still could be negative, so end-start could be negative(need check)
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, y_total);
}
else
{
// get the tile start/end range assum we loop over along Y tile by tile
index_t y_start = [&]() {
index_t tmp = max((-x + i_x + 1) * y_ratio, 0);
return (tmp / YTile) * YTile; // round to tile aligned
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t y_end = [&]() {
index_t tmp = min((i_x + XTile - 1) * y_ratio + y, y_total);
return ((tmp + YTile - 1) / YTile) * YTile;
}();
return ck_tile::make_tuple(y_start, y_end);
}
}
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
{
if constexpr(!IsMasking)
{
return i_x >= x_total;
}
else
{
index_t x_tmp = static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(i_y)));
index_t x_start = -y_real + x_tmp + 1;
index_t x_end = min(x_tmp + x,
x_total); // need min in case x is padded
return i_x < x_start || i_x >= x_end || i_y >= y_total;
}
}
// if current tile is at the edge, means need per-pixel mask check.
// otherwise no need to check per-pixel
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
// can be used as a fast-path to decide if do per-pixel check or not
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto
IsEdgeTile(index_t i_y, index_t i_x, number<TileHeight>, number<TileWidth>) const
{
if constexpr(!IsMasking)
{
// the only case that need do following compare is under kPadSeqLenK
// ... for non-masking kernel.
// return (i_x < x_total) && ((i_x + TileWidth) > x_total);
return (i_x + TileWidth) > x_total;
}
else
{
// check top-right corner > x or left-borrom corner < x
index_t i_x_end = i_x + TileWidth;
index_t i_y_end = i_y + TileHeight;
// index_t x_end = min(i_y + x, x_total);
uint32_t y_tmp = static_cast<uint32_t>(i_y);
bool top_right_edge = i_x_end > min(static_cast<index_t>(y_ratio_mdiv.div(y_tmp)) + x,
x_total); // consider right pad
bool bottom_left_edge =
i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad
return top_right_edge || bottom_left_edge;
}
}
private:
index_t y, x;
index_t y_total, x_total;
// y_real is vertical axis before multiplying y_ratio. y_real * y_ratio = y
index_t y_real;
index_t y_ratio;
mdiv y_ratio_mdiv;
};
// TODO: prefer use this function in host code
// can convert from the FA style left/right to our generic coordinate
// if left_size < 0 && right_size = 0, it is normal causal mask
// local is left_size >=0 or right_size >=0
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
index_t right_size,
index_t y_total,
index_t x_total,
bool is_top_left = true)
{
// TODO: below should all use sgpr arithmetic
index_t left_size_tmp = is_top_left ? y_total - 1 : x_total - 1;
index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1;
left_size = left_size < 0 ? left_size_tmp : left_size;
right_size = right_size < 0 ? right_size_tmp : right_size;
index_t x_tmp = is_top_left ? 0 : x_total - y_total;
index_t y_tmp = is_top_left ? 0 : y_total - x_total;
index_t x = 1 + right_size + x_tmp;
index_t y = 1 + left_size + y_tmp;
return ck_tile::make_tuple(y, x, y_total, x_total);
}
template <typename MaskType>
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_from_lr_window(index_t left_size,
index_t right_size,
index_t y_total,
index_t x_total,
bool is_top_left = true)
{
auto r = make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, is_top_left);
return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
}
} // namespace ck_tile

View File

@@ -0,0 +1,205 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include <cmath>
#include <vector>
namespace ck_tile {
enum struct PositionEncodingEnum
{
NO = 0,
ALIBI = 1,
};
/*
VERTICAL:
[0] 1 2 3 4 5
[0] 1 2 3 4 5
[0] 1 2 3 4 5
[0] 1 2 3 4 5
TOP_LEFT(but negative):
[0] 1 2 3 4 5
1 [0] 1 2 3 4
2 1 [0] 1 2 3
3 2 1 [0] 1 2
FROM_BOTTOM_RIGHT(but negative):
2 1 [0] 1 2 3
3 2 1 [0] 1 2
4 3 2 1 [0] 1
5 4 3 2 1 [0]
*/
enum struct AlibiMode
{
VERTICAL = 0,
FROM_TOP_LEFT = 1, // keep sync with mask enum
FROM_BOTTOM_RIGHT = 2,
};
template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
struct Alibi
{
static_assert(1 <= LogMaxSadOprndSize && LogMaxSadOprndSize <= 32,
"for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t");
// RowMajor here means if pixel within the same thread are along the row, or col
// this may impact the performance of update(), while the result are the same.
// e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false
CK_TILE_HOST_DEVICE Alibi(DataType slope_,
index_t y_total_,
index_t x_total_,
AlibiMode mode_ = AlibiMode::VERTICAL)
{
slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_;
shift_left_up = [&]() {
if(RowMajor)
{
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
}
else
{
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
}
}();
shift_right_down = [&]() {
if(RowMajor)
{
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
}
else
{
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
}
}();
mode = mode_;
}
CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) { return sad_u32(x, y, acc); }
CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
{
if constexpr(LogMaxSadOprndSize <= 16)
{
return sad_u16(
static_cast<uint16_t>(x), static_cast<uint16_t>(y), static_cast<uint16_t>(acc));
}
return sad_u32(x, y, acc);
}
CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
{
if constexpr(RowMajor)
{
// at least 3 instructions per row
index_t current_zero_point =
mode == AlibiMode::VERTICAL ? shift_right_down : row_idx + shift_right_down;
// for every threads, most of the pixels are along the row, below operation should be
// the main hot spot.
auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
bit_cast<uint32_t>(col_idx + shift_left_up),
0));
pixel += slope * position;
}
else
{
// at least 3 instructions per col;
index_t current_zero_point = mode == AlibiMode::VERTICAL
? row_idx + col_idx + shift_right_down
: col_idx + shift_right_down;
// for every threads, most of the pixels are along the col, below operation should be
// the main hot spot.
auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
bit_cast<uint32_t>(row_idx + shift_left_up),
0));
pixel += slope * position;
}
}
DataType slope; // float?
index_t shift_left_up; // always possitive
index_t shift_right_down; // always possitive
AlibiMode mode;
};
template <typename DataType>
struct EmptyPositionEncoding
{
CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/)
{
}
};
//
// can convert from the FA style left/right to our generic coordinate
// if left_size < 0 && right_size = 0, it is normal causal mask
// local is left_size >=0 or right_size >=0
template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
index_t window_left_size,
index_t window_right_size,
index_t y_total,
index_t x_total,
GenericAttentionMaskEnum mask_enum)
{
// assume mask_enum will never be NO_MASK, since if we do not have mask, it's
// totally OK to use constexpr
bool is_causal = window_left_size < 0 && window_right_size == 0;
AlibiMode alibi_mode =
is_causal ? AlibiMode::VERTICAL
: static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
return Alibi<DataType, RowMajor, LogMaxSadOprndSize>{slope, y_total, x_total, alibi_mode};
}
// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
// Do we need a device version?
template <typename DataType>
CK_TILE_HOST std::vector<DataType> get_alibi_slopes(ck_tile::index_t nheads)
{
auto get_slopes_power_of_2 = [](ck_tile::index_t n) {
float start = std::powf(
static_cast<float>(2),
-std::powf(static_cast<float>(2), -static_cast<float>((integer_log2_floor(n) - 3))));
std::vector<DataType> rtn;
for(auto i = 0; i < n; i++)
{
rtn.push_back(static_cast<DataType>(start * std::powf(start, i)));
}
return rtn;
};
if(is_power_of_two_integer(nheads))
{
// power of 2 calculation
return get_slopes_power_of_2(nheads);
}
else
{
ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads);
auto v0 = get_slopes_power_of_2(closest_power_of_2);
auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2);
auto v1_sliced = [&](auto vec, ck_tile::index_t rem) {
std::vector<DataType> sliced;
for(ck_tile::index_t i = 0; i < static_cast<ck_tile::index_t>(vec.size()); i++)
{
if(i % 2 == 0)
sliced.push_back(vec[i]);
}
std::vector<DataType> sliced_2(sliced.begin(), sliced.begin() + rem);
return sliced_2;
}(v1, nheads - closest_power_of_2);
v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end());
return v0;
}
}
} // namespace ck_tile

View File

@@ -0,0 +1,108 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace ck_tile {
// This class is used for codegen pattern matching
enum class RotaryEmbeddingEnum
{
NONE = 0,
INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc
HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
};
template <RotaryEmbeddingEnum>
struct RotaryEmbeddingEnumToStr;
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::NONE>
{
static constexpr const char* name = "";
};
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::INTERLEAVED>
{
static constexpr const char* name = "inter";
};
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::HALF_ROTATED>
{
static constexpr const char* name = "half";
};
template <RotaryEmbeddingEnum RotaryEnum, typename ComputeDataType = float>
struct BlockRotaryEmbedding
{
template <typename DistributedTensor,
typename OtherDramBlockWindow,
typename RotaryCosDramBlockWindow,
typename RotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile,
OtherDramBlockWindow other_window,
RotaryCosDramBlockWindow rotary_cos_window,
RotarySinDramBlockWindow rotary_sin_window,
index_t rotary_dim,
index_t thread_end)
{
using DataType = typename remove_cvref_t<DistributedTensor>::DataType;
if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
{
auto rotary_cos_tile = load_tile(rotary_cos_window);
auto rotary_sin_tile = load_tile(rotary_sin_window);
if(thread_end <= rotary_dim)
{
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
const auto left = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
const auto right = type_convert<ComputeDataType>(tile.thread_buf_[idx + 1]);
const auto cos =
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx / 2]);
const auto sin =
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx / 2]);
tile.thread_buf_[idx] = type_convert<DataType>(left * cos - right * sin);
tile.thread_buf_[idx + 1] = type_convert<DataType>(right * cos + left * sin);
});
}
}
else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
if(thread_end <= rotary_dim)
{
const bool is_left = (thread_end <= (rotary_dim / 2));
move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
auto other_tile = load_tile(other_window);
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
auto rotary_cos_tile = load_tile(rotary_cos_window);
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
auto rotary_sin_tile = load_tile(rotary_sin_window);
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
const auto curr = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
const auto other = type_convert<ComputeDataType>(other_tile.thread_buf_[idx]);
const auto cos =
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx]);
const auto sin =
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx]);
tile.thread_buf_[idx] =
type_convert<DataType>(curr * cos + other * (is_left ? -sin : sin));
});
}
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,358 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
namespace ck_tile {
// assume that we have only 1 page-block/tensor view
template <typename TensorView>
struct TrivialPageBlockNavigator
{
using DataType = typename TensorView::DataType;
using WindowOrigin = multi_index<2>;
CK_TILE_HOST_DEVICE constexpr TrivialPageBlockNavigator(const TensorView& tensor_view_)
: tensor_view(tensor_view_)
{
}
template <typename WindowLengths>
CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin) const
{
return make_tuple(/*block_index=*/0,
ck_tile::make_tile_window(tensor_view, window_lengths, window_origin));
}
template <typename WindowLengths, typename TileDistribution>
CK_TILE_HOST_DEVICE constexpr auto
make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin,
const TileDistribution& tile_distribution) const
{
return make_tuple(
/*block_index=*/0,
ck_tile::make_tile_window(
tensor_view, window_lengths, window_origin, tile_distribution));
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE static index_t
move_tile_window(index_t /*block_index*/,
TileWindow& tile_window,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step)
{
ck_tile::move_tile_window(tile_window, step);
return /*block_index=*/0;
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE index_t
move_tile_window(index_t /*block_index*/,
TileWindow& tile_window,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step,
index_t /*id*/) const
{
ck_tile::move_tile_window(tile_window, step);
return 0;
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE index_t
prefetch_table_id(index_t /*block_index*/,
TileWindow /*tile_window*/,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& /*step*/) const
{
return -1;
}
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
to_local_window_origin(const WindowOrigin& global_window_origin)
{
return global_window_origin;
}
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin)
{
return local_window_origin;
}
private:
TensorView tensor_view;
};
// default page-block navigator, assume that tensor view size is same as page-block size or smaller
// if tile window on last page-block
template <typename DataType_, index_t VirtualDim, typename TensorView>
struct PageBlockNavigator
{
using DataType = DataType_;
static_assert(std::is_same_v<DataType, typename TensorView::DataType>);
static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window");
using WindowOrigin = multi_index<2>;
CK_TILE_HOST_DEVICE constexpr PageBlockNavigator(copy_const_t<DataType, void>* physical_blocks_,
long_index_t block_stride_,
long_index_t fixed_offset_,
const int32_t* physical_block_indices_,
index_t num_blocks_,
index_t page_block_size_,
const TensorView& complete_view_,
const TensorView& last_view_)
: physical_blocks(reinterpret_cast<DataType*>(physical_blocks_)),
block_stride(block_stride_),
fixed_offset(fixed_offset_),
physical_block_indices(physical_block_indices_),
num_blocks(num_blocks_),
page_block_size(page_block_size_),
complete_view(complete_view_),
last_view(last_view_)
{
}
template <typename WindowLengths>
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin) const
{
const index_t block_index = get_block_index(window_origin);
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
auto new_tile_window =
ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
window_lengths,
local_window_origin);
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
return make_tuple(block_index, new_tile_window);
}
template <typename WindowLengths, typename TileDistribution>
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin,
const TileDistribution& tile_distribution) const
{
const index_t block_index = get_block_index(window_origin);
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
auto new_tile_window =
ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
window_lengths,
local_window_origin,
tile_distribution);
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
return make_tuple(block_index, new_tile_window);
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE index_t
move_tile_window(index_t block_index,
TileWindow& tile_window,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
{
ck_tile::move_tile_window(tile_window, step);
const WindowOrigin global_window_origin =
to_global_window_origin(block_index, tile_window.get_window_origin());
const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
const index_t new_block_index = get_block_index(global_window_origin);
/// TODO: only update necessary attributes
tile_window.bottom_tensor_view_.desc_ =
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
tile_window.set_window_origin(local_window_origin);
tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
return new_block_index;
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE index_t
move_tile_window(index_t block_index,
TileWindow& tile_window,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step,
index_t id) const
{
ck_tile::move_tile_window(tile_window, step);
const WindowOrigin global_window_origin =
to_global_window_origin(block_index, tile_window.get_window_origin());
const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
const index_t new_block_index = get_block_index(global_window_origin);
/// TODO: only update necessary attributes
tile_window.bottom_tensor_view_.desc_ =
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
tile_window.set_window_origin(local_window_origin);
if(id >= 0)
tile_window.set_bottom_tensor_view_data_ptr(physical_blocks + id * block_stride +
fixed_offset);
else
tile_window.set_bottom_tensor_view_data_ptr(nullptr);
return new_block_index;
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE index_t
prefetch_table_id(index_t block_index,
TileWindow& tile_window,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
{
auto local_tile_window = tile_window; // not affect origin window
ck_tile::move_tile_window(local_tile_window, step);
const WindowOrigin global_window_origin =
to_global_window_origin(block_index, local_tile_window.get_window_origin());
const index_t new_block_index = get_block_index(global_window_origin);
if(new_block_index < num_blocks)
{
return physical_block_indices[new_block_index];
}
else
{
return -1;
}
}
CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const
{
return block_index == num_blocks - 1;
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index,
const TileWindow& tile_window) const
{
const index_t origin = tile_window.get_window_origin().at(number<VirtualDim>{});
const index_t length = tile_window.get_window_lengths().at(number<VirtualDim>{});
return (block_index < num_blocks - 1) && (page_block_size < origin + length);
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE void
move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const
{
const multi_index<2> step = [&]() {
const index_t origin_diff = (block_index - new_block_index) * page_block_size;
if constexpr(VirtualDim == 0)
{
return make_multi_index(origin_diff, 0);
}
else
{
return make_multi_index(0, origin_diff);
}
}();
/// TODO: only update necessary attributes
tile_window.bottom_tensor_view_.desc_ =
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
tile_window.set_window_origin(tile_window.get_window_origin() + step);
tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
}
CK_TILE_HOST_DEVICE WindowOrigin
to_local_window_origin(const WindowOrigin& global_window_origin) const
{
if constexpr(VirtualDim == 0)
{
const index_t length = global_window_origin.at(number<0>{});
const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
return make_multi_index(length - page_block_size * num_complete_blocks,
global_window_origin.at(number<1>{}));
}
else
{
const index_t length = global_window_origin.at(number<1>{});
const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
return make_multi_index(global_window_origin.at(number<0>{}),
length - page_block_size * num_complete_blocks);
}
}
CK_TILE_HOST_DEVICE WindowOrigin
to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const
{
if constexpr(VirtualDim == 0)
{
return make_multi_index(block_index * page_block_size +
local_window_origin.at(number<0>{}),
local_window_origin.at(number<1>{}));
}
else
{
return make_multi_index(local_window_origin.at(number<0>{}),
block_index * page_block_size +
local_window_origin.at(number<1>{}));
}
}
private:
CK_TILE_HOST_DEVICE
DataType* get_block_ptr(index_t block_index) const
{
if(block_index < num_blocks)
{
return physical_blocks + physical_block_indices[block_index] * block_stride +
fixed_offset;
}
else
{
return nullptr;
}
}
CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const
{
return integer_divide_floor(global_window_origin.at(number<VirtualDim>{}), page_block_size);
}
DataType* physical_blocks;
long_index_t block_stride;
long_index_t fixed_offset;
const int32_t* physical_block_indices;
index_t num_blocks;
index_t page_block_size;
TensorView complete_view;
TensorView last_view;
};
template <typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView& tensor_view)
{
return TrivialPageBlockNavigator<TensorView>(tensor_view);
}
template <typename DataType, index_t VirtualDim, typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(copy_const_t<DataType, void>* physical_blocks,
long_index_t block_stride,
long_index_t fixed_offset,
const int32_t* physical_block_indices,
index_t num_blocks,
index_t page_block_size,
const TensorView& complete_view,
const TensorView& last_view)
{
return PageBlockNavigator<DataType, VirtualDim, TensorView>(physical_blocks,
block_stride,
fixed_offset,
physical_block_indices,
num_blocks,
page_block_size,
complete_view,
last_view);
}
} // namespace ck_tile

View File

@@ -0,0 +1,302 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <type_traits>
#include <ck_tile/core/numeric/math.hpp>
#include <ck_tile/core/numeric/type_convert.hpp>
#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 0
#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 1
#ifndef CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT
#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
#endif
#ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM
#define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0
#endif
namespace ck_tile {
namespace internal {
__device__ inline float
exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp)
{
#if(defined(__gfx90a__) || defined(__gfx94__)) && \
(CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
/// NOTICE: Make sure softmax_scale is stored in SGPR
float result, numerator, denominator;
asm volatile(
"v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n"
"v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n"
"v_rcp_f32_e32 %[denominator], %[denominator]\n"
"v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n"
"v_mul_f32_e32 %[result], %[numerator], %[denominator]"
: [numerator] "=&v"(numerator), [denominator] "=&v"(denominator), [result] "=v"(result)
: [softmax_scale] "s"(softmax_scale),
[logits] "v"(logits),
[logits_soft_cap_rcp] "v"(logits_soft_cap_rcp));
return result;
#else
return softmax_scale * logits * rcp<float>(1.f + abs(logits * logits_soft_cap_rcp));
#endif
}
} // namespace internal
template <typename ImplMask>
struct StandardAttentionParams
{
__device__ __host__ StandardAttentionParams(const ImplMask& impl_mask_, float sm_scale_)
: impl_mask(impl_mask_), sm_scale(sm_scale_)
{
}
const ImplMask& impl_mask;
float sm_scale;
};
template <typename ImplMask, bool UseExp2 = false>
struct LogitsSoftCapParams
{
__device__
LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_)
: impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_)
{
if(0.f < logits_soft_cap)
{
logits_soft_cap_rcp = __builtin_amdgcn_rcpf(logits_soft_cap);
}
else
{
logits_soft_cap_rcp = 0.f;
}
// move computation here to prevent compiler from generating inefficient instruction
// sequence
if constexpr(UseExp2)
{
logits_soft_cap = log2e_v<float> * logits_soft_cap;
logits_soft_cap_rcp = sm_scale * log2e_rcp_v<float> * logits_soft_cap_rcp;
}
}
__host__
LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_)
: impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_)
{
if(0.f < logits_soft_cap)
{
logits_soft_cap_rcp = 1.f / logits_soft_cap;
}
else
{
logits_soft_cap_rcp = 0.f;
}
// move computation here to prevent compiler from generating inefficient instruction
// sequence
if constexpr(UseExp2)
{
logits_soft_cap = log2e_v<float> * logits_soft_cap;
logits_soft_cap_rcp = sm_scale * log2e_rcp_v<float> * logits_soft_cap_rcp;
}
}
__device__ __host__ LogitsSoftCapParams(const ImplMask& impl_mask_,
float sm_scale_,
float logits_soft_cap_,
float logits_soft_cap_rcp_)
: impl_mask(impl_mask_),
sm_scale(sm_scale_),
logits_soft_cap(logits_soft_cap_),
logits_soft_cap_rcp(logits_soft_cap_rcp_)
{
// move computation here to prevent compiler from generating inefficient instruction
// sequence
if constexpr(UseExp2)
{
logits_soft_cap = log2e_v<float> * logits_soft_cap;
logits_soft_cap_rcp = sm_scale * log2e_rcp_v<float> * logits_soft_cap_rcp;
}
}
const ImplMask& impl_mask;
float sm_scale;
float logits_soft_cap;
float logits_soft_cap_rcp;
};
struct StandardAttention
{
__device__ __host__ StandardAttention() = default;
template <typename Params, typename T>
__device__ __forceinline__ T QueryTransform(const Params& params, T q) const
{
return type_convert<float>(q) * params.sm_scale;
}
/// NOTICE: For better performance, we simpliy transform thread buffer without calculating
/// qo_idx/kv_idx.
template <typename Params, typename T>
__device__ __forceinline__ T LogitsTransform([[maybe_unused]] const Params& params,
T logits,
[[maybe_unused]] uint32_t batch_idx,
/*uint32_t qo_idx, uint32_t kv_idx,*/
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return logits;
}
template <typename Params>
__device__ __forceinline__ bool LogitsMask(const Params& params,
[[maybe_unused]] uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
}
};
template <bool UseExp2 = false>
struct LogitsSoftCap
{
__device__ __host__ LogitsSoftCap() = default;
template <typename Params, typename T>
__device__ __forceinline__ T QueryTransform(const Params& params, T q) const
{
if constexpr(UseExp2)
{
return q;
}
else
{
return type_convert<float>(q) * params.sm_scale;
}
}
/// NOTICE: For better performance, we simpliy transform thread buffer without calculating
/// qo_idx/kv_idx.
template <typename Params, typename T>
__device__ __forceinline__ T LogitsTransform(const Params& params,
T logits,
[[maybe_unused]] uint32_t batch_idx,
/*uint32_t qo_idx, uint32_t kv_idx,*/
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
if constexpr(UseExp2)
{
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
return params.logits_soft_cap *
tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
return internal::exp2_soft_sign_impl(
params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
#endif
}
else
{
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
return params.logits_soft_cap *
tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
return type_convert<float>(logits) *
rcp<float>(1.f + abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
#endif
}
}
template <typename Params>
__device__ __forceinline__ bool LogitsMask(const Params& params,
[[maybe_unused]] uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
}
};
constexpr uint32_t CUSTOM_MASK = 1U;
constexpr uint32_t SLIDING_WINDOW = 2U;
constexpr uint32_t LOGITS_SOFT_CAP = 4U;
constexpr uint32_t ALIBI = 8U;
template <uint32_t VARIANT_CODE, bool UseExp2 = false>
struct ComposedAttention
{
static constexpr bool use_exp2 = UseExp2;
static constexpr bool use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0;
__device__ __host__ ComposedAttention() = default;
template <typename Params, typename T>
__device__ __forceinline__ T QueryTransform(const Params& params, T q) const
{
if constexpr(use_logits_soft_cap && UseExp2)
{
return q;
}
return type_convert<float>(q) * params.sm_scale;
}
/// NOTICE: For better performance, we simpliy transform thread buffer without calculating
/// qo_idx/kv_idx.
template <typename Params, typename T>
__device__ __forceinline__ T LogitsTransform(const Params& params,
T logits,
[[maybe_unused]] uint32_t batch_idx,
/*uint32_t qo_idx, uint32_t kv_idx,*/
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
if constexpr(use_logits_soft_cap)
{
if constexpr(UseExp2)
{
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
return params.logits_soft_cap *
tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
return internal::exp2_soft_sign_impl(
params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
#endif
}
else
{
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
return params.logits_soft_cap *
tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
return type_convert<float>(logits) *
rcp<float>(1.f +
abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
#endif
}
}
return logits;
}
template <typename Params>
__device__ __forceinline__ bool LogitsMask(const Params& params,
[[maybe_unused]] uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
[[maybe_unused]] uint32_t qo_head_idx,
[[maybe_unused]] uint32_t kv_head_idx) const
{
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,450 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include <type_traits>
#include <utility>
namespace ck_tile {
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdV3Kernel
{
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
// The attention is default causal
struct UnifiedAttentionCommonKargs
{
const void* q_ptr;
const void* k_ptr; // [num_blks, blk_size, num_kv_heads, head_size]
const void* v_ptr; // [num_blks, blk_size, num_kv_heads, head_size]
void* o_ptr;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_head_q;
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
// if this param is larger than 1, indicate MQA/GQA case
ck_tile::index_t num_queries_per_kv;
// scales
float scale_s;
float scale;
float scale_k;
float scale_v;
float scale_out;
ck_tile::index_t total_num_q_blocks;
ck_tile::index_t query_stride_0;
ck_tile::index_t query_stride_1;
ck_tile::index_t stride_k_cache_0;
ck_tile::index_t stride_k_cache_1;
ck_tile::index_t stride_k_cache_2;
ck_tile::index_t stride_k_cache_3;
ck_tile::index_t stride_v_cache_0;
ck_tile::index_t stride_v_cache_1;
ck_tile::index_t stride_v_cache_2;
ck_tile::index_t stride_v_cache_3;
ck_tile::index_t output_stride_0;
ck_tile::index_t output_stride_1;
ck_tile::index_t HEAD_SIZE_PADDED;
};
struct UnifiedAttentionVarlenKargs
{
const int32_t* block_tables_ptr;
const int32_t* seq_lens_ptr; // seq len in each batch
const int32_t* query_start_len_ptr; // [num_seqs+1]
ck_tile::index_t num_seqs; // number of batches for q
ck_tile::index_t BLOCK_SIZE; // Block size for kv cache. to 2's exponent????
ck_tile::index_t BLOCK_Q; // Block size for kv cache. to 2's exponent????
};
struct Kargs {
UnifiedAttentionCommonKargs unifiedAttentionCommonKargs;
UnifiedAttentionVarlenKargs unifiedAttentionVarlenKargs;
};
// using Kargs = FmhaFwdGroupModeKargs;
CK_TILE_HOST static constexpr Kargs MakeKargs(
const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
void* o_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t num_queries_per_kv,
float scale_s,
float scale,
float scale_k,
float scale_v,
float scale_out,
ck_tile::index_t total_num_q_blocks,
ck_tile::index_t query_stride_0,
ck_tile::index_t query_stride_1,
ck_tile::index_t stride_k_cache_0,
ck_tile::index_t stride_k_cache_1,
ck_tile::index_t stride_k_cache_2,
ck_tile::index_t stride_k_cache_3,
ck_tile::index_t stride_v_cache_0,
ck_tile::index_t stride_v_cache_1,
ck_tile::index_t stride_v_cache_2,
ck_tile::index_t stride_v_cache_3,
ck_tile::index_t output_stride_0,
ck_tile::index_t output_stride_1,
const int32_t* block_tables_ptr,
const int32_t* seq_lens_ptr,
const int32_t* query_start_len_ptr,
ck_tile::index_t num_seqs,
ck_tile::index_t BLOCK_SIZE,
ck_tile::index_t BLOCK_Q
)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
o_ptr,
hdim_q,
hdim_v,
num_head_q,
num_queries_per_kv,
static_cast<float>(scale_s * ck_tile::log2e_v<>),
scale,
scale_k,
scale_v,
scale_out,
total_num_q_blocks,
query_stride_0,
query_stride_1,
stride_k_cache_0,
stride_k_cache_1,
stride_k_cache_2,
stride_k_cache_3,
stride_v_cache_0,
stride_v_cache_1,
stride_v_cache_2,
stride_v_cache_3,
output_stride_0,
output_stride_1},
{
block_tables_ptr,
seq_lens_ptr,
query_start_len_ptr,
num_seqs,
BLOCK_SIZE,
BLOCK_Q,
}};
return kargs;
}
CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads,
ck_tile::index_t total_num_q_blocks)
{
return dim3(num_kv_heads * total_num_q_blocks, 0, 0);
}
// CK_TILE_HOST static constexpr auto GridSize3D(ck_tile::index_t num_kv_heads,
// ck_tile::index_t total_num_q_blocks)
// {
// // TODO: fix 3D grid
// return dim2(num_kv_heads, total_num_q_blocks);
// }
// Binary search to find the sequence index for a given target index
CK_TILE_DEVICE static constexpr ck_tile::index_t
find_seq_idx(const int32_t* query_start_len_ptr,
ck_tile::index_t target_idx,
ck_tile::index_t num_seqs,
ck_tile::index_t BLOCK_Q,
bool use_q_block_mode)
{
ck_tile::index_t left = 0;
ck_tile::index_t right = num_seqs;
while (left < right)
{
ck_tile::index_t mid = (left + right) / 2;
ck_tile::index_t val = query_start_len_ptr[mid];
ck_tile::index_t mid_val = use_q_block_mode ? (val / BLOCK_Q + mid) : val;
if (mid_val <= target_idx)
{
left = mid + 1;
}
else
{
right = mid;
}
}
return left - 1;
}
CK_TILE_DEVICE static constexpr auto
RemapTileIndices(const ck_tile::index_t pid, const Kargs& kargs)
{
using namespace ck_tile;
constexpr index_t NUM_XCDS = 8;
const index_t GRID_MN = kargs.unifiedAttentionCommonKargs.total_num_q_blocks *
(kargs.unifiedAttentionCommonKargs.num_head_q);
// Number of pids per XCD in the new arrangement
const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS;
// When GRID_MN cannot divide NUM_XCDS, some xcds will have
// pids_per_xcd pids, the other will have pids_per_xcd - 1 pids.
// We calculate the number of xcds that have pids_per_xcd pids as tall_xcds
index_t tall_xcds = GRID_MN % NUM_XCDS;
tall_xcds = tall_xcds == 0 ? NUM_XCDS : tall_xcds;
// Compute current XCD and local pid within the XCD
const index_t xcd = pid % NUM_XCDS;
const index_t local_pid = pid / NUM_XCDS;
// Calculate new pid based on the new grouping
index_t remapped_pid = 0; // Initialize to avoid constexpr error
if(xcd < tall_xcds)
{
remapped_pid = xcd * pids_per_xcd + local_pid;
}
else
{
remapped_pid = tall_xcds * pids_per_xcd +
(xcd - tall_xcds) * (pids_per_xcd - 1) +
local_pid;
}
return remapped_pid;
}
CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid, const Kargs& kargs)
{
using namespace ck_tile;
ck_tile::index_t total_num_q_blocks = kargs.unifiedAttentionCommonKargs.total_num_q_blocks;
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
// FmhaPipeline::kN1);
const index_t i_tile_m = pid % total_num_q_blocks; // Query block index
const index_t i_tile_n = pid / total_num_q_blocks; // Head index
return ck_tile::make_tuple(i_tile_m, i_tile_n);
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
using namespace ck_tile;
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
ck_tile::index_t pid = blockIdx.x;
pid = RemapTileIndices(pid, kargs);
// divide problem
const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs);
const index_t seq_idx = find_seq_idx(
kargs.unifiedAttentionVarlenKargs.query_start_len_ptr, q_block_global_idx, kargs.unifiedAttentionVarlenKargs.num_seqs, kargs.unifiedAttentionCommonKargs.BLOCK_Q, true
); // which batch
const index_t q_block_start_idx = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]);
const index_t q_block_local_idx = amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);
const index_t cur_batch_in_all_start_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx]);
const index_t cur_batch_in_all_stop_index = amd_wave_read_first_lane(kargs.unifiedAttentionVarlenKargs.query_start_len_ptr[seq_idx + 1]);
const index_t cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index;
// TODO check if we get the block size info from pipeline
if (q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q >= cur_batch_query_len) {
return;
}
const index_t query_pos = q_block_local_idx * kargs.unifiedAttentionVarlenKargs.BLOCK_Q;
// for simplicity, batch stride we just modify the pointer
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.unifiedAttentionCommonKargs.q_ptr) +
static_cast<long_index_t>(kv_head_idx) * kargs.unifiedAttentionCommonKargs.num_queries_per_kv * kargs.unifiedAttentionCommonKargs.query_stride_1 +
static_cast<long_index_t>(cur_batch_in_all_start_index) * kargs.unifiedAttentionCommonKargs.query_stride_0;
// const KDataType* k_ptr =
// reinterpret_cast<const KDataType*>(kargs.k_ptr) +
// static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
// batch_offset_k;
// const VDataType* v_ptr =
// reinterpret_cast<const VDataType*>(kargs.v_ptr) +
// static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
// batch_offset_v;
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
batch_offset_o;
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(kargs.seqlen_q, kargs.unifiedAttentionVarlenKargs.),
make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<1>{});
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
const auto k_dram = [&]() {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
const auto v_dram = [&]() {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
sequence<kPadSeqLenK, kPadHeadDimV>{});
}();
auto q_dram_window = make_tile_window(
q_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
{i_m0, 0});
auto k_dram_window = make_tile_window(
k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
{0, i_n1});
// lse
auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
if constexpr(kStoreLSE)
{
LSEDataType* lse_ptr =
reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
const auto lse_dram = [&]() {
const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
lse_ptr,
make_tuple(kargs.seqlen_q),
make_tuple(1),
number<1>{},
number<1>{});
return pad_tensor_view(
lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
}();
return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
}
else
{
return make_null_tile_window(lse_dram_window_lengths);
}
}();
FmhaMask mask = [&]() {
if constexpr(kHasMask)
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
else
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();
auto o_acc_tile = [&]() {
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
lse_dram_window,
mask,
kargs.scale_s,
smem_ptr);
}();
// O DRAM and O DRAM window
auto o_dram = [&]() {
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.stride_o, 1),
number<FmhaPipeline::kAlignmentO>{},
number<1>{});
return pad_tensor_view(
o_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
}();
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
}
};
} // namespace ck_tile

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,603 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
namespace ck_tile {
struct UnifiedAttentionPipelineDefaultPolicy
{
static constexpr ck_tile::index_t NumWarpPerGroup = 4;
static constexpr ck_tile::index_t NumThreadPerWarpGroup =
NumWarpPerGroup * ck_tile::get_warp_size();
// TODO: GetAlignment*() currently didn't consider if need padding or not
// so in pipeline still need check padding requirement
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetAlignmentK()
{
using namespace ck_tile;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
#if defined(__gfx950__)
constexpr index_t MaxReadSizeInBytes = 16;
#else
constexpr index_t MaxReadSizeInBytes = 4;
#endif
return MaxReadSizeInBytes / sizeof(KDataType);
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetAlignmentV()
{
using namespace ck_tile;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
#if defined(__gfx950__)
constexpr index_t MaxReadSizeInBytes = 16;
#else
constexpr index_t MaxReadSizeInBytes = 4;
#endif
return MaxReadSizeInBytes / sizeof(VDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
{
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
{
using namespace ck_tile;
// TODO: this is for 3d layout
using KDataType = remove_cvref_t<typename Problem::KDataType>;
return 16 / sizeof(KDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemVPackK()
{
using namespace ck_tile;
// TODO: this is for 3d layout
using VDataType = remove_cvref_t<typename Problem::VDataType>;
return 16 / sizeof(VDataType);
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeKDramTileDistribution()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK0;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr index_t N0 = NumIssues;
constexpr index_t N1 = LaneGroups;
constexpr index_t N2 = NumWarps;
constexpr index_t K0 = LanesPerK;
constexpr index_t K1 = KVector;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kK1;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kN1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr index_t N0 = NumIssues;
constexpr index_t N1 = LaneGroups;
constexpr index_t N2 = NumWarps;
constexpr index_t K0 = LanesPerK;
constexpr index_t K1 = KVector;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeQRegTileDistribution()
{
using namespace ck_tile;
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeKRegTileDistribution()
{
using namespace ck_tile;
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
return make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakePRegTileDistribution()
{
using namespace ck_tile;
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVRegTileDistribution()
{
using namespace ck_tile;
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN1;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK1;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding(
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
// compute the endcoding before transpose
constexpr auto v_block_dstr =
make_static_tile_distribution(typename InputTileDistributionTraits<
decltype(v_block_dstr_encode),
typename Problem::VDataType>::TransposedDstrEncode{});
return v_block_dstr;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetQKBlockGemm()
{
using namespace ck_tile;
using GemmProblem =
BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::UnifiedAttentionShape::kM0,
Problem::UnifiedAttentionShape::kN0,
Problem::UnifiedAttentionShape::kK0>,
typename Problem::UnifiedAttentionShape::Gemm0BlockWarps,
typename Problem::UnifiedAttentionShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
/// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use
/// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here
return WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
/// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use
/// WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution here
return WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>{};
}
}();
using BlockGemmPolicy =
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
typename Problem::UnifiedAttentionShape::Gemm0BlockWarps,
decltype(warp_gemm),
GemmLoopOrder::MNK>;
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetPVBlockGemm()
{
using namespace ck_tile;
using GemmProblem =
BlockGemmProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kBlockSize,
TileGemmShape<sequence<Problem::UnifiedAttentionShape::kM0,
Problem::UnifiedAttentionShape::kN1,
Problem::UnifiedAttentionShape::kK1>,
typename Problem::UnifiedAttentionShape::Gemm1BlockWarps,
typename Problem::UnifiedAttentionShape::Gemm1WarpTile>>;
/// NOTICE: in order to use load_tile_transpose() later for V tiles, we have to pass
/// WGAttrNumAccessEnum::Double instead of WGAttrNumAccessEnum::Single
using WarpGemm = WarpGemmDispatcher<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<0>{}),
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}),
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}),
true,
false,
false,
WGAttrNumAccessEnum::Double>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
typename Problem::UnifiedAttentionShape::Gemm1BlockWarps,
WarpGemm,
GemmLoopOrder::MNK>;
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
}
static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords
static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords
template <typename Problem, ck_tile::index_t IBuf = 0>
CK_TILE_DEVICE static constexpr auto
MakeKLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
{
using namespace ck_tile;
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK0;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
[[maybe_unused]] constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
kKLdsPadInBytes /
sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps.
// Optimize this for lds_read speed
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK =
kKPerBlock / KVector; // how many lane (within a wave) to load K
constexpr index_t LaneGroups =
WarpSize /
LanesPerK; // how many groups (within a wave), they may load different N, but same K
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KVector>{},
number<1>{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return k_lds_block_desc_issues_warps_lanes;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor()
{
using namespace ck_tile;
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK0;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
kKLdsPadInBytes /
sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
number<NumWarps>{}, // n2
number<LaneGroups>{}, // n1
number<kKPerBlock / KPack>{}, // k0
number<KPack>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<WarpSize * KVector + kPad>{},
number<kKPerBlock>{},
number<KPack>{},
number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
{
// this function assume K/V can share smem
constexpr index_t SingleKSize = [&]() {
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK1;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack;
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector;
constexpr index_t LaneGroups = WarpSize / LanesPerK;
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
return NumIssues * NumWarps * (WarpSize * KVector + kPad);
}();
constexpr index_t SingleVSize = [&]() {
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackK<Problem>();
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN1;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK1;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kKPerBlock % kKPack == 0);
return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
}();
return max(SingleKSize, SingleVSize);
}
template <typename Problem, ck_tile::index_t IBuf = 0>
CK_TILE_DEVICE static constexpr auto
MakeVLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
{
using namespace ck_tile;
/// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kK1;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kN1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
[[maybe_unused]] constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
constexpr index_t kPad =
kVLdsPadInBytes /
sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps.
// Optimize this for lds_read speed
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK =
kKPerBlock / KVector; // how many lane (within a wave) to load K
constexpr index_t LaneGroups =
WarpSize /
LanesPerK; // how many groups (within a wave), they may load different N, but same K
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<(IBuf + 2) * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KVector>{},
number<1>{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return v_lds_block_desc_issues_warps_lanes;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVLdsLoadBlockDescriptor()
{
using namespace ck_tile;
/// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kK1;
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kN1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
kVLdsPadInBytes /
sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto v_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
number<NumWarps>{}, // n2
number<LaneGroups>{}, // n1
number<kKPerBlock / KPack>{}, // k0
number<KPack>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<WarpSize * KVector + kPad>{},
number<kKPerBlock>{},
number<KPack>{},
number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
{
using namespace ck_tile;
static_assert(MakeKLdsLoadBlockDescriptor<Problem>().get_element_space_size() ==
MakeKLdsStoreBlockDescriptor<Problem>().get_element_space_size());
constexpr index_t k_element_space_size =
MakeKLdsLoadBlockDescriptor<Problem>().get_element_space_size();
static_assert(MakeVLdsLoadBlockDescriptor<Problem>().get_element_space_size() ==
MakeVLdsStoreBlockDescriptor<Problem>().get_element_space_size());
constexpr index_t v_element_space_size =
MakeVLdsLoadBlockDescriptor<Problem>().get_element_space_size();
static_assert(ck_tile::max(k_element_space_size, v_element_space_size) <=
GetSingleSmemElementSpaceSize<Problem>());
/// TODO: override GetSingleSmemElementSpaceSize() to align with MakeKLdsBlockDescriptor() &
/// MakeVLdsBlockDescriptor()
static_assert(std::is_same_v<typename Problem::KDataType, typename Problem::VDataType>);
constexpr index_t kv_element_space_size_in_bytes =
GetSingleSmemElementSpaceSize<Problem>() * sizeof(typename Problem::KDataType);
return kv_element_space_size_in_bytes;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return 4 * GetSmemSizeKV<Problem>();
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,42 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockFmhaPipelineEnum
{
QRKSVS = 0,
QRKSVS_ASYNC,
QSKSVS,
QRKSVS_ASYNC_TRLOAD,
};
template <BlockFmhaPipelineEnum>
struct BlockFmhaPipelineEnumToStr;
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS>
{
static constexpr const char* name = "qr";
};
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC>
{
static constexpr const char* name = "qr_async";
};
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QSKSVS>
{
static constexpr const char* name = "qs";
};
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD>
{
static constexpr const char* name = "qr_async_trload";
};
} // namespace ck_tile

View File

@@ -0,0 +1,60 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp"
namespace ck_tile {
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
typename SaccDataType_,
typename SMPLComputeDataType_,
typename BiasDataType_,
typename RandValOutputDataType_,
typename PDataType_,
typename OaccDataType_,
typename ODataType_,
typename UnifiedAttentionShape_,
typename Traits_>
struct UnifiedAttentionPipelineProblem
{
// TODO kM0 and KN1??
using QDataType = remove_cvref_t<QDataType_>;
using KDataType = remove_cvref_t<KDataType_>;
using VDataType = remove_cvref_t<VDataType_>;
// first gemm accumulation dtype
using SaccDataType = remove_cvref_t<SaccDataType_>;
// Softmax dtype
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
using BiasDataType = remove_cvref_t<BiasDataType_>;
using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
// data type for A matrix of second gemm
using PDataType = remove_cvref_t<PDataType_>;
// data type for second gemm accumulation
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using UnifiedAttentionShape = remove_cvref_t<UnifiedAttentionShape_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kNumGemm0Warps = UnifiedAttentionShape::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = UnifiedAttentionShape::NumGemm1Warps;
static constexpr index_t kBlockSize = UnifiedAttentionShape::NumWarps * get_warp_size();
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
}