mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
initial commit
This commit is contained in:
@@ -0,0 +1,37 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Kind of bias applied to the attention scores.
// This enum is used for codegen pattern matching, so the enumerator names and their numeric
// values must stay stable.
enum class BlockAttentionBiasEnum
{
    NO_BIAS          = 0, // no bias is applied
    ELEMENTWISE_BIAS = 1, // attention bias: each element is added to the result of Q*K (after scale)
    ALIBI            = 2, // bias computed with position encoding, applied after scale
};
|
||||
|
||||
template <BlockAttentionBiasEnum>
|
||||
struct BlockAttentionBiasEnumToStr;
|
||||
|
||||
template <>
|
||||
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::NO_BIAS>
|
||||
{
|
||||
static constexpr const char* name = "";
|
||||
};
|
||||
template <>
|
||||
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ELEMENTWISE_BIAS>
|
||||
{
|
||||
static constexpr const char* name = "bias";
|
||||
};
|
||||
template <>
|
||||
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ALIBI>
|
||||
{
|
||||
static constexpr const char* name = "alibi";
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
654
include/ck_tile/ops/unified_attention/block/block_dropout.hpp
Normal file
654
include/ck_tile/ops/unified_attention/block/block_dropout.hpp
Normal file
@@ -0,0 +1,654 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// BlockDropoutBwd and BlockDropout (fwd) support two warp gemm tile sizes: 32x32 (MFMA only) and
// 16x16 (MFMA and WMMA). Even if fwd and bwd use different tile sizes, the generated random
// numbers will be the same; they are also the same for MFMA (on CDNA), WMMA (on RDNA), or host
// (for verification, see ck_tile/host/reference/reference_batched_dropout_randval.hpp).
//
// The (row, col) coordinate of the current 32x32 tile in the P matrix determines a subsequence of
// random numbers (ph_subsequence).
// The (batch, head, 0..63) coordinate determines an offset in the subsequence (ph_head_offset and
// ph_offset).
// This means that subsequences are non-overlapping, reproducible and independent of mask or window.
//
// There are 3 modes (all produce the same results):
//  * For 32x32 MFMA tile each of 64 lanes generates 4 * 32 bits or 16 bytes, so one warp generates
//    the entire 32x32 tile (64 * 16 = 32 * 32).
//  * For 16x16 MFMA tile one warp generates 1/4 of the 32x32 tile ((16 * 16) / (64 * 16) = 1/4), 4
//    warps generate the same 64 * 16 random bytes and each uses its own quarter. If kMPerBlock >
//    MWarp * WG::kM one warp can generate two 16x16 tiles (MIterPerWarp = 2) so fewer instructions
//    are needed for generating a 32x32 tile.
//  * For 16x16 WMMA tile one warp generates 1/2 of the 32x32 tile ((16 * 16) / (32 * 16) = 1/2), 2
//    warps generate the same 64 * 16 random bytes and each uses its own half. If kMPerBlock > MWarp *
//    WG::kM one warp can generate two 16x16 tiles.
|
||||
|
||||
namespace detail {
|
||||
// The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values
|
||||
constexpr index_t philox_per_tile = 64;
|
||||
} // namespace detail
|
||||
|
||||
struct NullBlockDropout
|
||||
{
|
||||
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
index_t seqlen_qk_start)
|
||||
{
|
||||
(void)randval_dram_block_window_tmp;
|
||||
(void)seqlen_qk_start;
|
||||
|
||||
return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
|
||||
}
|
||||
};
|
||||
|
||||
struct BlockDropout
|
||||
{
|
||||
CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch,
|
||||
index_t i_head,
|
||||
index_t nheads,
|
||||
unsigned long long seed,
|
||||
unsigned long long offset,
|
||||
float rp_undrop_,
|
||||
uint8_t p_undrop_in_uint8_t_,
|
||||
bool is_store_randval_)
|
||||
: ph_seed(amd_wave_read_first_lane(seed)),
|
||||
ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) *
|
||||
detail::philox_per_tile)),
|
||||
rp_undrop(rp_undrop_),
|
||||
p_undrop_in_uint8_t(p_undrop_in_uint8_t_),
|
||||
is_store_randval(is_store_randval_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
index_t seqlen_qk_start)
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr bool IsWG32 = WG::kM == 32;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
|
||||
constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
|
||||
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
|
||||
auto randval_dram_window = [&]() {
|
||||
if constexpr(IsFwd)
|
||||
{
|
||||
return make_tile_window(
|
||||
randval_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
|
||||
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
randval_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
|
||||
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
|
||||
}
|
||||
}();
|
||||
|
||||
return randval_dram_window;
|
||||
}
|
||||
|
||||
template <typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr bool IsWG32 = WG::kM == 32;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
|
||||
constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
constexpr index_t kN1 = 8;
|
||||
constexpr index_t kN0 = kNPerStep / kN1;
|
||||
|
||||
constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
ck_tile::make_tuple(number<kN0>{}, number<kMPerStep>{}, number<kN1>{}),
|
||||
ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
|
||||
number<kN1>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
|
||||
randval_lds_block_desc_0,
|
||||
ck_tile::make_tuple(
|
||||
make_pass_through_transform(number<kMPerStep>{}),
|
||||
make_merge_transform(ck_tile::make_tuple(number<kN0>{}, number<kN1>{}))),
|
||||
ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
ck_tile::make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return randval_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr bool IsWG32 = WG::kM == 32;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
|
||||
constexpr index_t NIterPerWarp = 1;
|
||||
|
||||
// The tile distribution is different from the one in MakeRandValLdsShuffleTileDistribution,
|
||||
// because it can combine 2 (MIterPerWarp) 16x16 subtiles for generating them at once
|
||||
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MWarp, MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 0>>{};
|
||||
|
||||
// Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
|
||||
constexpr auto randval_block_inner_part_dstr_encoding =
|
||||
typename WarpGemmDispatcher<typename WG::ADataType,
|
||||
typename WG::BDataType,
|
||||
typename WG::CDataType,
|
||||
WG::kM,
|
||||
WG::kN,
|
||||
WG::kK,
|
||||
false,
|
||||
IsWG32>::CWarpDstrEncoding{};
|
||||
|
||||
constexpr auto randval_block_part_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
|
||||
randval_block_inner_part_dstr_encoding);
|
||||
|
||||
return make_static_tile_distribution(randval_block_part_dstr_encode);
|
||||
}
|
||||
|
||||
template <typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution()
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr bool IsWG32 = WG::kM == 32;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
|
||||
constexpr index_t NIterPerWarp = 1;
|
||||
|
||||
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto randval_block_part_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
|
||||
typename WG::CWarpDstrEncoding{});
|
||||
|
||||
return make_static_tile_distribution(randval_block_part_dstr_encode);
|
||||
}
|
||||
|
||||
template <typename BlockGemm,
|
||||
typename PComputeDataType,
|
||||
typename RandValOutputDataType,
|
||||
typename PComputeWindow,
|
||||
typename RandValDramWindow>
|
||||
CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
|
||||
const index_t start_n0_idx,
|
||||
PComputeWindow& p_compute,
|
||||
RandValDramWindow& randval_dram_window) const
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr bool IsWG32 = WG::kM == 32;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
|
||||
constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
|
||||
// randval tile in LDS
|
||||
auto randval_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
|
||||
|
||||
auto randval_lds_window = make_tile_window(
|
||||
randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
|
||||
|
||||
// register distribute
|
||||
auto randval_dist_generated =
|
||||
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
|
||||
|
||||
const auto randval_lds_read_window =
|
||||
make_tile_window(randval_lds_window.get_bottom_tensor_view(),
|
||||
randval_lds_window.get_window_lengths(),
|
||||
randval_lds_window.get_window_origin(),
|
||||
MakeRandValLdsShuffleTileDistribution<BlockGemm>());
|
||||
|
||||
const index_t start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
auto generate_randval = [&](auto i_m0, auto i_n0) {
|
||||
// Generate random numbers
|
||||
uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
|
||||
const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
|
||||
const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
|
||||
if constexpr(IsWG32)
|
||||
{
|
||||
// Generate the whole 32x32 tile at once (each tile consists of random numbers taken
|
||||
// from a separate subsequence of Philox)
|
||||
const unsigned long long ph_subsequence =
|
||||
bit_cast<unsigned long long>(make_uint2(wg_m0, wg_n0));
|
||||
const index_t ph_offset = get_lane_id();
|
||||
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
|
||||
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
|
||||
ph.get_random_16x8(random_uint8_t, ph_subsequence);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether
|
||||
// MIterPerWarp is equal to 1 or 2)
|
||||
const unsigned long long ph_subsequence =
|
||||
bit_cast<unsigned long long>(make_uint2(wg_m0 / 2, wg_n0 / 2));
|
||||
const index_t subtile_m0 = wg_m0 % 2;
|
||||
if constexpr(get_warp_size() == 32)
|
||||
{
|
||||
const index_t ph_offset = (get_lane_id() & 15) +
|
||||
(((get_lane_id() >> 4) & 1) << 5) +
|
||||
((wg_n0 % 2) << 4);
|
||||
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
|
||||
if constexpr(MIterPerWarp == 1)
|
||||
{
|
||||
static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
|
||||
ph.get_random_8x8(
|
||||
random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
|
||||
ph.get_random_16x8(random_uint8_t, ph_subsequence);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t subtile_n0 = (get_lane_id() >> 4) & 1;
|
||||
const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4);
|
||||
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
|
||||
if constexpr(MIterPerWarp == 1)
|
||||
{
|
||||
static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
|
||||
ph.get_random_4x8(
|
||||
random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
|
||||
ph.get_random_8x8(
|
||||
random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
constexpr auto randval_dist_generated_spans =
|
||||
decltype(randval_dist_generated)::get_distributed_spans();
|
||||
int i_random_idx = 0;
|
||||
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
|
||||
});
|
||||
});
|
||||
// Transpose randval using LDS
|
||||
store_tile(randval_lds_window, randval_dist_generated);
|
||||
block_sync_lds();
|
||||
const auto randval = load_tile(randval_lds_read_window);
|
||||
block_sync_lds();
|
||||
return randval;
|
||||
};
|
||||
|
||||
if(is_store_randval)
|
||||
{
|
||||
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
|
||||
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
|
||||
const auto randval = generate_randval(i_m0, i_n0);
|
||||
// save to Global
|
||||
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
|
||||
store_tile(randval_dram_window, randval_store);
|
||||
move_tile_window(randval_dram_window, {0, kNPerStep});
|
||||
});
|
||||
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
|
||||
});
|
||||
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
|
||||
}
|
||||
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
|
||||
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
|
||||
const auto randval = generate_randval(i_m0, i_n0);
|
||||
// Drop values of P based on the generated probabilities
|
||||
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
|
||||
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto p_idx0 =
|
||||
tile_distributed_index<i_m0 * MIterPerWarp +
|
||||
idx0.impl_.template at<0>()>{};
|
||||
constexpr auto p_idx1 =
|
||||
tile_distributed_index<i_n0,
|
||||
idx1.impl_.template at<1>(),
|
||||
idx1.impl_.template at<2>()>{};
|
||||
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
|
||||
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
|
||||
? p_compute[p_idx] * rp_undrop
|
||||
: PComputeDataType(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
const unsigned long long ph_seed;
|
||||
const unsigned long long ph_head_offset;
|
||||
const float rp_undrop;
|
||||
const uint8_t p_undrop_in_uint8_t;
|
||||
const bool is_store_randval;
|
||||
};
|
||||
|
||||
// TODO: IsWG32_ is not needed as template parameter and can be removed. IsDropout_ == false can be
|
||||
// replaced with NullBlockDropout. This requires changes in xformers and other libs.
|
||||
template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
|
||||
struct BlockDropoutBwd;
|
||||
|
||||
template <bool IsWG32_, bool IsStoreRandval_>
|
||||
struct BlockDropoutBwd<false, IsWG32_, IsStoreRandval_>
|
||||
{
|
||||
static constexpr bool IsDropout = false;
|
||||
static constexpr bool IsStoreRandval = IsStoreRandval_;
|
||||
|
||||
template <typename BlockGemm, bool IsFwd = false, typename RandValDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
index_t seqlen_qk_start)
|
||||
{
|
||||
(void)randval_dram_block_window_tmp;
|
||||
(void)seqlen_qk_start;
|
||||
|
||||
return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
|
||||
}
|
||||
};
|
||||
|
||||
template <bool IsWG32_, bool IsStoreRandval_>
|
||||
struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
|
||||
{
|
||||
static constexpr bool IsDropout = true;
|
||||
static constexpr bool IsStoreRandval = IsStoreRandval_;
|
||||
|
||||
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch,
|
||||
index_t i_head,
|
||||
index_t nheads,
|
||||
unsigned long long seed,
|
||||
unsigned long long offset,
|
||||
float rp_undrop_,
|
||||
uint8_t p_undrop_in_uint8_t_)
|
||||
: ph_seed(amd_wave_read_first_lane(seed)),
|
||||
ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) *
|
||||
detail::philox_per_tile)),
|
||||
rp_undrop(rp_undrop_),
|
||||
p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename BlockGemm, bool IsFwd = false, typename RandValDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
index_t seqlen_qk_start)
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr bool IsWG32 = WG::kM == 32;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
|
||||
constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
|
||||
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
|
||||
auto randval_dram_window = [&]() {
|
||||
if constexpr(IsFwd)
|
||||
{
|
||||
return make_tile_window(
|
||||
randval_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
|
||||
{block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
randval_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
|
||||
{seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
|
||||
}
|
||||
}();
|
||||
|
||||
return randval_dram_window;
|
||||
}
|
||||
|
||||
template <typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr bool IsWG32 = WG::kM == 32;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
|
||||
constexpr index_t NIterPerWarp = 1;
|
||||
|
||||
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MWarp, MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 0>>{};
|
||||
|
||||
constexpr auto randval_block_inner_part_dstr_encoding =
|
||||
typename WarpGemmDispatcher<typename WG::ADataType,
|
||||
typename WG::BDataType,
|
||||
typename WG::CDataType,
|
||||
WG::kM,
|
||||
WG::kN,
|
||||
WG::kK,
|
||||
false,
|
||||
IsWG32>::CWarpDstrEncoding{};
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(randval_block_inner_part_dstr_encoding)>,
|
||||
typename WG::CWarpDstrEncoding>);
|
||||
|
||||
constexpr auto randval_block_part_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
|
||||
randval_block_inner_part_dstr_encoding);
|
||||
|
||||
return make_static_tile_distribution(randval_block_part_dstr_encode);
|
||||
}
|
||||
|
||||
template <typename BlockGemm,
|
||||
typename RandValOutputDataType,
|
||||
typename PComputeWindow,
|
||||
typename RandValDramWindow>
|
||||
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
|
||||
const index_t start_n0_idx,
|
||||
PComputeWindow& p_compute,
|
||||
RandValDramWindow& randval_dram_window) const
|
||||
{
|
||||
constexpr auto config =
|
||||
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr bool IsWG32 = WG::kM == 32;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
|
||||
constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
|
||||
constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
|
||||
constexpr index_t kNPerStep = NWarp * WG::kN;
|
||||
|
||||
// register distribute
|
||||
auto randval_dist_generated =
|
||||
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
auto generate_randval = [&](auto i_m0, auto i_n0) {
|
||||
// Generate random numbers
|
||||
uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
|
||||
const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
|
||||
const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
|
||||
if constexpr(IsWG32)
|
||||
{
|
||||
// Generate the whole 32x32 tile at once (each tile consists of random numbers
|
||||
// taken from a separate subsequence of Philox)
|
||||
const unsigned long long ph_subsequence =
|
||||
bit_cast<unsigned long long>(make_uint2(wg_m0, wg_n0));
|
||||
const index_t ph_offset = get_lane_id();
|
||||
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
|
||||
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
|
||||
ph.get_random_16x8(random_uint8_t, ph_subsequence);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether
|
||||
// MIterPerWarp is equal to 1 or 2)
|
||||
const unsigned long long ph_subsequence =
|
||||
bit_cast<unsigned long long>(make_uint2(wg_m0 / 2, wg_n0 / 2));
|
||||
const index_t subtile_m0 = wg_m0 % 2;
|
||||
if constexpr(get_warp_size() == 32)
|
||||
{
|
||||
const index_t ph_offset = (get_lane_id() & 15) +
|
||||
(((get_lane_id() >> 4) & 1) << 5) +
|
||||
((wg_n0 % 2) << 4);
|
||||
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
|
||||
if constexpr(MIterPerWarp == 1)
|
||||
{
|
||||
static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
|
||||
ph.get_random_8x8(
|
||||
random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
|
||||
ph.get_random_16x8(random_uint8_t, ph_subsequence);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t subtile_n0 = (get_lane_id() >> 4) & 1;
|
||||
const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4);
|
||||
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
|
||||
if constexpr(MIterPerWarp == 1)
|
||||
{
|
||||
static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
|
||||
ph.get_random_4x8(
|
||||
random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
|
||||
ph.get_random_8x8(
|
||||
random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
constexpr auto randval_dist_generated_spans =
|
||||
decltype(randval_dist_generated)::get_distributed_spans();
|
||||
int i_random_idx = 0;
|
||||
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
|
||||
});
|
||||
});
|
||||
return randval_dist_generated;
|
||||
};
|
||||
|
||||
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
|
||||
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
|
||||
const auto randval = generate_randval(i_m0, i_n0);
|
||||
// Drop values of P based on the generated probabilities, negative sign is used to
|
||||
// distinguish such values later in bwd pipeline.
|
||||
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
|
||||
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
constexpr auto p_idx0 =
|
||||
tile_distributed_index<i_m0 * MIterPerWarp +
|
||||
idx0.impl_.template at<0>(),
|
||||
idx0.impl_.template at<1>(),
|
||||
idx0.impl_.template at<2>()>{};
|
||||
constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
|
||||
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
|
||||
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
|
||||
? p_compute[p_idx]
|
||||
: -p_compute[p_idx];
|
||||
});
|
||||
});
|
||||
// save to Global
|
||||
if constexpr(IsStoreRandval)
|
||||
{
|
||||
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
|
||||
store_tile(randval_dram_window, randval_store);
|
||||
move_tile_window(randval_dram_window, {kMPerStep, 0});
|
||||
}
|
||||
});
|
||||
if constexpr(IsStoreRandval)
|
||||
{
|
||||
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
|
||||
}
|
||||
});
|
||||
if constexpr(IsStoreRandval)
|
||||
{
|
||||
move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
|
||||
}
|
||||
}
|
||||
|
||||
const unsigned long long ph_seed;
|
||||
const unsigned long long ph_head_offset;
|
||||
const float rp_undrop;
|
||||
const uint8_t p_undrop_in_uint8_t;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
642
include/ck_tile/ops/unified_attention/block/block_masking.hpp
Normal file
642
include/ck_tile/ops/unified_attention/block/block_masking.hpp
Normal file
@@ -0,0 +1,642 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Kind of attention mask. The numeric values are explicit for the first three enumerators;
// MASK_GENERIC has no initializer and therefore takes the next value, 3.
enum struct GenericAttentionMaskEnum
{
    NO_MASK = 0,

    // the two enumerators below can describe either a causal or a sliding-window mask
    MASK_FROM_TOP_LEFT     = 1,
    MASK_FROM_BOTTOM_RIGHT = 2,

    // this enumerator may not be used by xformer/FA, since it is hard to
    // specify the left/right window for the varlen case. It is kept here for
    // debugging purposes.
    MASK_GENERIC,
};
|
||||
|
||||
// clang-format off
|
||||
/*  generic Attention Mask Coordinate
    use x(horizontal axis), y(vertical axis) to describe mask.
    top-left corner is origin

    x=1/y=5(top-left)  x=4/y=5(botm-r)    x=6/y=5            x=8/y=5(no mask)
    1 * * * * * * *    1 1 1 1 * * * *    1 1 1 1 1 1 * *    1 1 1 1 1 1 1 1
    1 1 * * * * * *    1 1 1 1 1 * * *    1 1 1 1 1 1 1 *    1 1 1 1 1 1 1 1
    1 1 1 * * * * *    1 1 1 1 1 1 * *    1 1 1 1 1 1 1 1    1 1 1 1 1 1 1 1
    1 1 1 1 * * * *    1 1 1 1 1 1 1 *    1 1 1 1 1 1 1 1    1 1 1 1 1 1 1 1
    1 1 1 1 1 * * *    1 1 1 1 1 1 1 1    1 1 1 1 1 1 1 1    1 1 1 1 1 1 1 1
    l=7,-1/r=0(tl)     l=7,-1/r=0(br)

    x=1/y=2            x=4/y=2            x=6/y=2            x=8/y=2
    1 * * * * * * *    1 1 1 1 * * * *    1 1 1 1 1 1 * *    1 1 1 1 1 1 1 1
    1 1 * * * * * *    1 1 1 1 1 * * *    1 1 1 1 1 1 1 *    1 1 1 1 1 1 1 1
    * 1 1 * * * * *    * 1 1 1 1 1 * *    * 1 1 1 1 1 1 1    * 1 1 1 1 1 1 1
    * * 1 1 * * * *    * * 1 1 1 1 1 *    * * 1 1 1 1 1 1    * * 1 1 1 1 1 1
    * * * 1 1 * * *    * * * 1 1 1 1 1    * * * 1 1 1 1 1    * * * 1 1 1 1 1
    l=1/r=0(tl)        l=1/r=3(tl)        l=1/r=5(tl)        l=1/r=7(tl)
                       l=4/r=0(br)        l=4/r=2(br)        l=4/r=4(br)

                       x=4/y=-1           x=6/y=-1           x=8/y=-1
                       * * 1 1 * * * *    * * 1 1 1 1 * *    * * 1 1 1 1 1 1
                       * * * 1 1 * * *    * * * 1 1 1 1 *    * * * 1 1 1 1 1
                       * * * * 1 1 * *    * * * * 1 1 1 1    * * * * 1 1 1 1
                       * * * * * 1 1 *    * * * * * 1 1 1    * * * * * 1 1 1
                       * * * * * * 1 1    * * * * * * 1 1    * * * * * * 1 1

    x=-2/y=5           x=1/y=5(top-left)  x=0/y=5(botm-r)
    * * * * * * * *    1 * * *            * * * *
    * * * * * * * *    1 1 * *            1 * * *
    * * * * * * * *    1 1 1 *            1 1 * *
    1 * * * * * * *    1 1 1 1            1 1 1 *
    1 1 * * * * * *    1 1 1 1            1 1 1 1

    Validations:
        x + y > 1 (x + y >= 2)

    Note:
        y = seq_q, x = 1 -> top-left
        y = seq_q, x = seq_k - seq_q + 1 -> bottom-right
        y < seq_q, x < seq_k -> local-attn
        y = seq_q, x = seq_k -> no mask

*/
|
||||
namespace impl {
// Short name tag for each <IsMasking_, IsLocal_> combination.
// Note that both non-masking combinations map to the same tag "mn".
template <bool IsMasking_, bool IsLocal_> struct MaskName;
template<> struct MaskName<false, false> { static constexpr const char * name = "mn"; };
template<> struct MaskName<false, true>  { static constexpr const char * name = "mn"; };
template<> struct MaskName<true,  false> { static constexpr const char * name = "mc"; };
template<> struct MaskName<true,  true>  { static constexpr const char * name = "mg"; };
}
|
||||
// clang-format on
|
||||
|
||||
template <bool IsMasking_ = true, bool IsLocal_ = false>
|
||||
struct GenericAttentionMask
|
||||
{
|
||||
static constexpr bool IsMasking = IsMasking_; // false will disable masking
|
||||
static constexpr bool IsLocal = IsLocal_; // if true, upper/lower area could have mask,
|
||||
// else only upper-right could have mask
|
||||
|
||||
static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
|
||||
|
||||
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_)
|
||||
: GenericAttentionMask(0, 0, y_total_, x_total_)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
|
||||
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
|
||||
{
|
||||
}
|
||||
template <typename MaskCoordinates>
|
||||
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord)
|
||||
: y(mask_coord.at(number<0>{})),
|
||||
x(mask_coord.at(number<1>{})),
|
||||
y_total(mask_coord.at(number<2>{})),
|
||||
x_total(mask_coord.at(number<3>{}))
|
||||
{
|
||||
}
|
||||
|
||||
// to get the loop length along X axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over X axis tile by tile (like k-seqlen loopover)
|
||||
// TODO: x_end still could be negative, so end-start could be negative(need check)
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, x_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along X tile by tile
|
||||
index_t x_start = [&]() {
|
||||
if constexpr(IsLocal)
|
||||
{
|
||||
index_t tmp = max(-y + i_y + 1, 0);
|
||||
return (tmp / XTile) * XTile; // round to tile aligned
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t x_end = [&]() {
|
||||
index_t tmp = min(i_y + YTile - 1 + x, x_total);
|
||||
return ((tmp + XTile - 1) / XTile) * XTile;
|
||||
}();
|
||||
|
||||
return ck_tile::make_tuple(x_start, x_end);
|
||||
}
|
||||
}
|
||||
|
||||
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
||||
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, y_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along Y tile by tile
|
||||
index_t y_start = [&]() {
|
||||
index_t tmp = max(-x + i_x + 1, 0);
|
||||
return (tmp / YTile) * YTile; // round to tile aligned
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t y_end = [&]() {
|
||||
index_t tmp = min(i_x + XTile - 1 + y, y_total);
|
||||
return ((tmp + YTile - 1) / YTile) * YTile;
|
||||
}();
|
||||
|
||||
return ck_tile::make_tuple(y_start, y_end);
|
||||
}
|
||||
}
|
||||
|
||||
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
|
||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return i_x >= x_total;
|
||||
}
|
||||
else
|
||||
{
|
||||
// no need to do min/max here, since i_x will never be < 0 or >= x_total
|
||||
index_t x_start = -y + i_y + 1;
|
||||
index_t x_end = min(i_y + x, x_total);
|
||||
|
||||
if constexpr(IsLocal)
|
||||
{
|
||||
return i_x < x_start || i_x >= x_end;
|
||||
}
|
||||
else
|
||||
{
|
||||
return i_x >= x_end || i_y >= y_total;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if current tile is at the edge, means need per-pixel mask check.
|
||||
// otherwise no need to check per-pixel
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
||||
// can be used as a fast-path to decide if do per-pixel check or not
|
||||
template <index_t TileHeight, index_t TileWidth>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number<TileHeight>, number<TileWidth>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
// TODO: no need to check begin
|
||||
return (i_tile_left + TileWidth) > x_total;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(IsLocal)
|
||||
{
|
||||
// check top-right corner > x or left-borrom corner < x
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t i_tile_bottom = i_tile_top + TileHeight;
|
||||
index_t x_end = min(i_tile_top + x, x_total);
|
||||
|
||||
bool top_right_edge = i_tile_right > (i_tile_top + x);
|
||||
bool bottom_left_edge = i_tile_bottom > (i_tile_left + y);
|
||||
bool is_partial_out_of_bound =
|
||||
i_tile_right > x_end; // only consider right-pad for now
|
||||
|
||||
return top_right_edge || bottom_left_edge || is_partial_out_of_bound;
|
||||
}
|
||||
else
|
||||
{
|
||||
// only need to check top-right corner > x
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t x_end = min(i_tile_top + x, x_total);
|
||||
|
||||
bool top_right_edge = i_tile_right > x_end;
|
||||
return top_right_edge;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
index_t y, x;
|
||||
index_t y_total, x_total;
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
namespace impl {
// Name tag for the simplified mask: "nomask" when masking is disabled, "mask" otherwise.
template <bool IsMasking_> struct SimplifiedMaskName;
template<> struct SimplifiedMaskName<false> { static constexpr const char * name = "nomask"; };
template<> struct SimplifiedMaskName<true>  { static constexpr const char * name = "mask"; };
}
|
||||
// clang-format on
|
||||
|
||||
// this version only have 2 variation: masking and non-masking
|
||||
// This is more friendly to codegen (e.g. need generate less kernel)
|
||||
// ... with the trade-off that may have more instruction in causal mode
|
||||
template <bool IsMasking_ = true>
|
||||
struct SimplifiedGenericAttentionMask
|
||||
{
|
||||
static constexpr bool IsMasking = IsMasking_; // false will disable masking
|
||||
|
||||
static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
|
||||
|
||||
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
|
||||
: SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
|
||||
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
|
||||
{
|
||||
}
|
||||
template <typename MaskCoordinates>
|
||||
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
|
||||
: y(mask_coord.at(number<0>{})),
|
||||
x(mask_coord.at(number<1>{})),
|
||||
y_total(mask_coord.at(number<2>{})),
|
||||
x_total(mask_coord.at(number<3>{}))
|
||||
{
|
||||
}
|
||||
|
||||
// to get the loop length along X axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over X axis tile by tile (like k-seqlen loopover)
|
||||
// TODO: x_end still could be negative, so end-start could be negative(need check)
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, x_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along X tile by tile
|
||||
index_t x_start = [&]() {
|
||||
index_t tmp = max(-y + i_y + 1, 0);
|
||||
return (tmp / XTile) * XTile; // round to tile aligned
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t x_end = [&]() {
|
||||
index_t tmp = min(i_y + YTile - 1 + x, x_total);
|
||||
return ((tmp + XTile - 1) / XTile) * XTile;
|
||||
}();
|
||||
|
||||
return ck_tile::make_tuple(x_start, x_end);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t TileHeight, index_t TileWidth>
|
||||
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y,
|
||||
number<TileHeight> height,
|
||||
number<TileWidth> width,
|
||||
index_t num_splits,
|
||||
index_t i_split) const
|
||||
{
|
||||
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
|
||||
|
||||
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
|
||||
const index_t split_start = x_per_split * i_split;
|
||||
const index_t split_end = ck_tile::min(x_total, split_start + x_per_split);
|
||||
|
||||
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
|
||||
ck_tile::min(origin_end, split_end));
|
||||
}
|
||||
|
||||
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
||||
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, y_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along Y tile by tile
|
||||
index_t y_start = [&]() {
|
||||
index_t tmp = max(-x + i_x + 1, 0);
|
||||
return (tmp / YTile) * YTile; // round to tile aligned
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t y_end = [&]() {
|
||||
index_t tmp = min(i_x + XTile - 1 + y, y_total);
|
||||
return ((tmp + YTile - 1) / YTile) * YTile;
|
||||
}();
|
||||
|
||||
return ck_tile::make_tuple(y_start, y_end);
|
||||
}
|
||||
}
|
||||
|
||||
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
|
||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
// the only case that need do following compare is under kPadSeqLenK
|
||||
// ... for non-masking kernel.
|
||||
return i_x >= x_total;
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
|
||||
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
|
||||
|
||||
return i_x < x_start || i_x >= x_end || i_y >= y_total;
|
||||
}
|
||||
}
|
||||
|
||||
// if current tile is at the edge, means need per-pixel mask check.
|
||||
// otherwise no need to check per-pixel
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
||||
// can be used as a fast-path to decide if do per-pixel check or not
|
||||
template <index_t TileHeight, index_t TileWidth>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
IsEdgeTile(index_t i_y, index_t i_x, number<TileHeight>, number<TileWidth>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
// the only case that need do following compare is under kPadSeqLenK
|
||||
// ... for non-masking kernel.
|
||||
// return (i_x < x_total) && ((i_x + TileWidth) > x_total);
|
||||
|
||||
// TODO: no need to check begin
|
||||
return (i_x + TileWidth) > x_total;
|
||||
}
|
||||
else
|
||||
{
|
||||
// check top-right corner > x or left-borrom corner < x
|
||||
index_t i_x_end = i_x + TileWidth;
|
||||
index_t i_y_end = i_y + TileHeight;
|
||||
// index_t x_end = min(i_y + x, x_total);
|
||||
|
||||
bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad
|
||||
bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad
|
||||
// bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now
|
||||
|
||||
return top_right_edge || bottom_left_edge;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
index_t y, x;
|
||||
index_t y_total, x_total;
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
namespace impl {
// Name tag for the ratio mask: "nomask" when masking is disabled, "mask" otherwise.
template <bool IsMasking_> struct SimplifiedRatioMaskName;
template<> struct SimplifiedRatioMaskName<false> { static constexpr const char * name = "nomask"; };
template<> struct SimplifiedRatioMaskName<true>  { static constexpr const char * name = "mask"; };
}
|
||||
// clang-format on
|
||||
|
||||
// this version is used for cases that the step length of y-direction changes greater than one. It
|
||||
// means that the mask is not a regular triangular matrix.
|
||||
|
||||
// clang-format off
|
||||
/* y_ratio is used to describe the step length of y-direction changes
|
||||
in certain performance optimization scenarios like merging seqlen
|
||||
and qk_head_ratio, for example:
|
||||
|
||||
x=1/y=6/y_ratio=2(top-left)
|
||||
1 * * * * * * *
|
||||
1 * * * * * * *
|
||||
1 1 * * * * * *
|
||||
1 1 * * * * * *
|
||||
1 1 1 * * * * *
|
||||
1 1 1 * * * * *
|
||||
|
||||
*/
|
||||
// clang-format on
|
||||
template <bool IsMasking_ = true>
|
||||
struct SimplifiedRatioAttentionMask
|
||||
{
|
||||
static constexpr bool IsMasking = IsMasking_; // false will disable masking
|
||||
|
||||
static constexpr const char* name = impl::SimplifiedRatioMaskName<IsMasking>::name;
|
||||
|
||||
CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_total_, index_t x_total_)
|
||||
: SimplifiedRatioAttentionMask(0, 0, y_total_, x_total_, 0, 1, mdiv{})
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
SimplifiedRatioAttentionMask(
|
||||
index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_)
|
||||
: SimplifiedRatioAttentionMask(/*y_=*/y_real_ * static_cast<index_t>(y_ratio_mdiv_.get()),
|
||||
/*x_=*/x_,
|
||||
/*y_total_=*/y_total_,
|
||||
/*x_total_=*/x_total_,
|
||||
/*y_real_=*/y_real_,
|
||||
/*y_ratio_=*/static_cast<index_t>(y_ratio_mdiv_.get()),
|
||||
/*y_ratio_mdiv_=*/y_ratio_mdiv_)
|
||||
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE
|
||||
SimplifiedRatioAttentionMask(index_t y_,
|
||||
index_t x_,
|
||||
index_t y_total_,
|
||||
index_t x_total_,
|
||||
index_t y_real_,
|
||||
index_t y_ratio_,
|
||||
mdiv y_ratio_mdiv_)
|
||||
: y(y_),
|
||||
x(x_),
|
||||
y_total(y_total_),
|
||||
x_total(x_total_),
|
||||
y_real(y_real_),
|
||||
y_ratio(y_ratio_),
|
||||
y_ratio_mdiv(y_ratio_mdiv_)
|
||||
{
|
||||
}
|
||||
|
||||
// to get the loop length along X axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over X axis tile by tile (like k-seqlen loopover)
|
||||
// TODO: x_end still could be negative, so end-start could be negative(need check)
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, x_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along X tile by tile
|
||||
index_t x_start = [&]() {
|
||||
index_t tmp = -y_real +
|
||||
static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(i_y))) +
|
||||
1;
|
||||
|
||||
return (tmp / XTile) * XTile; // round to tile aligned
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t x_end = [&]() {
|
||||
uint32_t y_offset = i_y + YTile - 1;
|
||||
index_t tmp = min(static_cast<index_t>(y_ratio_mdiv.div(y_offset)) + x, x_total);
|
||||
return ((tmp + XTile - 1) / XTile) * XTile;
|
||||
}();
|
||||
|
||||
return ck_tile::make_tuple(x_start, x_end);
|
||||
}
|
||||
}
|
||||
|
||||
// to get the loop length along Y axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
|
||||
// TODO: y_end still could be negative, so end-start could be negative(need check)
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongY(index_t i_x, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, y_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along Y tile by tile
|
||||
index_t y_start = [&]() {
|
||||
index_t tmp = max((-x + i_x + 1) * y_ratio, 0);
|
||||
return (tmp / YTile) * YTile; // round to tile aligned
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t y_end = [&]() {
|
||||
index_t tmp = min((i_x + XTile - 1) * y_ratio + y, y_total);
|
||||
return ((tmp + YTile - 1) / YTile) * YTile;
|
||||
}();
|
||||
|
||||
return ck_tile::make_tuple(y_start, y_end);
|
||||
}
|
||||
}
|
||||
|
||||
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
|
||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return i_x >= x_total;
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t x_tmp = static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(i_y)));
|
||||
index_t x_start = -y_real + x_tmp + 1;
|
||||
index_t x_end = min(x_tmp + x,
|
||||
x_total); // need min in case x is padded
|
||||
return i_x < x_start || i_x >= x_end || i_y >= y_total;
|
||||
}
|
||||
}
|
||||
|
||||
// if current tile is at the edge, means need per-pixel mask check.
|
||||
// otherwise no need to check per-pixel
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
|
||||
// can be used as a fast-path to decide if do per-pixel check or not
|
||||
template <index_t TileHeight, index_t TileWidth>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
IsEdgeTile(index_t i_y, index_t i_x, number<TileHeight>, number<TileWidth>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
// the only case that need do following compare is under kPadSeqLenK
|
||||
// ... for non-masking kernel.
|
||||
// return (i_x < x_total) && ((i_x + TileWidth) > x_total);
|
||||
|
||||
return (i_x + TileWidth) > x_total;
|
||||
}
|
||||
else
|
||||
{
|
||||
// check top-right corner > x or left-borrom corner < x
|
||||
index_t i_x_end = i_x + TileWidth;
|
||||
index_t i_y_end = i_y + TileHeight;
|
||||
// index_t x_end = min(i_y + x, x_total);
|
||||
uint32_t y_tmp = static_cast<uint32_t>(i_y);
|
||||
bool top_right_edge = i_x_end > min(static_cast<index_t>(y_ratio_mdiv.div(y_tmp)) + x,
|
||||
x_total); // consider right pad
|
||||
bool bottom_left_edge =
|
||||
i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad
|
||||
return top_right_edge || bottom_left_edge;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
index_t y, x;
|
||||
index_t y_total, x_total;
|
||||
// y_real is vertical axis before multiplying y_ratio. y_real * y_ratio = y
|
||||
index_t y_real;
|
||||
index_t y_ratio;
|
||||
mdiv y_ratio_mdiv;
|
||||
};
|
||||
|
||||
// TODO: prefer use this function in host code
|
||||
// can convert from the FA style left/right to our generic coordinate
|
||||
// if left_size < 0 && right_size = 0, it is normal causal mask
|
||||
// local is left_size >=0 or right_size >=0
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
|
||||
index_t right_size,
|
||||
index_t y_total,
|
||||
index_t x_total,
|
||||
bool is_top_left = true)
|
||||
{
|
||||
// TODO: below should all use sgpr arithmetic
|
||||
index_t left_size_tmp = is_top_left ? y_total - 1 : x_total - 1;
|
||||
index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1;
|
||||
|
||||
left_size = left_size < 0 ? left_size_tmp : left_size;
|
||||
right_size = right_size < 0 ? right_size_tmp : right_size;
|
||||
|
||||
index_t x_tmp = is_top_left ? 0 : x_total - y_total;
|
||||
index_t y_tmp = is_top_left ? 0 : y_total - x_total;
|
||||
|
||||
index_t x = 1 + right_size + x_tmp;
|
||||
index_t y = 1 + left_size + y_tmp;
|
||||
|
||||
return ck_tile::make_tuple(y, x, y_total, x_total);
|
||||
}
|
||||
|
||||
template <typename MaskType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_generic_attention_mask_from_lr_window(index_t left_size,
|
||||
index_t right_size,
|
||||
index_t y_total,
|
||||
index_t x_total,
|
||||
bool is_top_left = true)
|
||||
{
|
||||
auto r = make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, y_total, x_total, is_top_left);
|
||||
return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,205 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Kinds of position encoding.
enum struct PositionEncodingEnum
{
    NO    = 0,
    ALIBI = 1,
};
|
||||
|
||||
/*
|
||||
VERTICAL:
|
||||
[0] 1 2 3 4 5
|
||||
[0] 1 2 3 4 5
|
||||
[0] 1 2 3 4 5
|
||||
[0] 1 2 3 4 5
|
||||
|
||||
TOP_LEFT(but negative):
|
||||
[0] 1 2 3 4 5
|
||||
1 [0] 1 2 3 4
|
||||
2 1 [0] 1 2 3
|
||||
3 2 1 [0] 1 2
|
||||
|
||||
FROM_BOTTOM_RIGHT(but negative):
|
||||
2 1 [0] 1 2 3
|
||||
3 2 1 [0] 1 2
|
||||
4 3 2 1 [0] 1
|
||||
5 4 3 2 1 [0]
|
||||
*/
|
||||
|
||||
// Layout of the ALiBi distance matrix.
// FROM_TOP_LEFT / FROM_BOTTOM_RIGHT must keep the same numeric values as the
// corresponding mask enum entries, because the mask enum is cast to AlibiMode.
enum struct AlibiMode
{
    VERTICAL          = 0,
    FROM_TOP_LEFT     = 1, // keep in sync with mask enum
    FROM_BOTTOM_RIGHT = 2,
};
|
||||
|
||||
template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
|
||||
struct Alibi
|
||||
{
|
||||
static_assert(1 <= LogMaxSadOprndSize && LogMaxSadOprndSize <= 32,
|
||||
"for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t");
|
||||
|
||||
// RowMajor here means if pixel within the same thread are along the row, or col
|
||||
// this may impact the performance of update(), while the result are the same.
|
||||
// e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false
|
||||
CK_TILE_HOST_DEVICE Alibi(DataType slope_,
|
||||
index_t y_total_,
|
||||
index_t x_total_,
|
||||
AlibiMode mode_ = AlibiMode::VERTICAL)
|
||||
{
|
||||
slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_;
|
||||
|
||||
shift_left_up = [&]() {
|
||||
if(RowMajor)
|
||||
{
|
||||
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
|
||||
}
|
||||
}();
|
||||
shift_right_down = [&]() {
|
||||
if(RowMajor)
|
||||
{
|
||||
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
|
||||
}
|
||||
}();
|
||||
mode = mode_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) { return sad_u32(x, y, acc); }
|
||||
|
||||
CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
|
||||
{
|
||||
if constexpr(LogMaxSadOprndSize <= 16)
|
||||
{
|
||||
return sad_u16(
|
||||
static_cast<uint16_t>(x), static_cast<uint16_t>(y), static_cast<uint16_t>(acc));
|
||||
}
|
||||
|
||||
return sad_u32(x, y, acc);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
|
||||
{
|
||||
if constexpr(RowMajor)
|
||||
{
|
||||
// at least 3 instructions per row
|
||||
index_t current_zero_point =
|
||||
mode == AlibiMode::VERTICAL ? shift_right_down : row_idx + shift_right_down;
|
||||
|
||||
// for every threads, most of the pixels are along the row, below operation should be
|
||||
// the main hot spot.
|
||||
auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
|
||||
bit_cast<uint32_t>(col_idx + shift_left_up),
|
||||
0));
|
||||
pixel += slope * position;
|
||||
}
|
||||
else
|
||||
{
|
||||
// at least 3 instructions per col;
|
||||
index_t current_zero_point = mode == AlibiMode::VERTICAL
|
||||
? row_idx + col_idx + shift_right_down
|
||||
: col_idx + shift_right_down;
|
||||
|
||||
// for every threads, most of the pixels are along the col, below operation should be
|
||||
// the main hot spot.
|
||||
auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
|
||||
bit_cast<uint32_t>(row_idx + shift_left_up),
|
||||
0));
|
||||
pixel += slope * position;
|
||||
}
|
||||
}
|
||||
|
||||
DataType slope; // float?
|
||||
index_t shift_left_up; // always possitive
|
||||
index_t shift_right_down; // always possitive
|
||||
AlibiMode mode;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct EmptyPositionEncoding
|
||||
{
|
||||
CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// can convert from the FA style left/right to our generic coordinate
|
||||
// if left_size < 0 && right_size = 0, it is normal causal mask
|
||||
// local is left_size >=0 or right_size >=0
|
||||
template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
|
||||
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
|
||||
index_t window_left_size,
|
||||
index_t window_right_size,
|
||||
index_t y_total,
|
||||
index_t x_total,
|
||||
GenericAttentionMaskEnum mask_enum)
|
||||
{
|
||||
// assume mask_enum will never be NO_MASK, since if we do not have mask, it's
|
||||
// totally OK to use constexpr
|
||||
bool is_causal = window_left_size < 0 && window_right_size == 0;
|
||||
AlibiMode alibi_mode =
|
||||
is_causal ? AlibiMode::VERTICAL
|
||||
: static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
|
||||
return Alibi<DataType, RowMajor, LogMaxSadOprndSize>{slope, y_total, x_total, alibi_mode};
|
||||
}
|
||||
|
||||
// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
|
||||
// Do we need a device version?
|
||||
template <typename DataType>
|
||||
CK_TILE_HOST std::vector<DataType> get_alibi_slopes(ck_tile::index_t nheads)
|
||||
{
|
||||
auto get_slopes_power_of_2 = [](ck_tile::index_t n) {
|
||||
float start = std::powf(
|
||||
static_cast<float>(2),
|
||||
-std::powf(static_cast<float>(2), -static_cast<float>((integer_log2_floor(n) - 3))));
|
||||
|
||||
std::vector<DataType> rtn;
|
||||
for(auto i = 0; i < n; i++)
|
||||
{
|
||||
rtn.push_back(static_cast<DataType>(start * std::powf(start, i)));
|
||||
}
|
||||
return rtn;
|
||||
};
|
||||
if(is_power_of_two_integer(nheads))
|
||||
{
|
||||
// power of 2 calculation
|
||||
return get_slopes_power_of_2(nheads);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads);
|
||||
auto v0 = get_slopes_power_of_2(closest_power_of_2);
|
||||
auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2);
|
||||
auto v1_sliced = [&](auto vec, ck_tile::index_t rem) {
|
||||
std::vector<DataType> sliced;
|
||||
for(ck_tile::index_t i = 0; i < static_cast<ck_tile::index_t>(vec.size()); i++)
|
||||
{
|
||||
if(i % 2 == 0)
|
||||
sliced.push_back(vec[i]);
|
||||
}
|
||||
std::vector<DataType> sliced_2(sliced.begin(), sliced.begin() + rem);
|
||||
return sliced_2;
|
||||
}(v1, nheads - closest_power_of_2);
|
||||
v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end());
|
||||
return v0;
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,108 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This enum is used for codegen pattern matching
enum class RotaryEmbeddingEnum
{
    NONE         = 0,
    INTERLEAVED  = 1, // combine dimensions 0 & 1, 2 & 3, etc
    HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
};

// Maps each RotaryEmbeddingEnum value to a short name; NONE maps to the empty string.
template <RotaryEmbeddingEnum>
struct RotaryEmbeddingEnumToStr;

template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::NONE>
{
    static constexpr const char* name = "";
};

template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::INTERLEAVED>
{
    static constexpr const char* name = "inter";
};

template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::HALF_ROTATED>
{
    static constexpr const char* name = "half";
};
|
||||
|
||||
template <RotaryEmbeddingEnum RotaryEnum, typename ComputeDataType = float>
|
||||
struct BlockRotaryEmbedding
|
||||
{
|
||||
template <typename DistributedTensor,
|
||||
typename OtherDramBlockWindow,
|
||||
typename RotaryCosDramBlockWindow,
|
||||
typename RotarySinDramBlockWindow>
|
||||
CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile,
|
||||
OtherDramBlockWindow other_window,
|
||||
RotaryCosDramBlockWindow rotary_cos_window,
|
||||
RotarySinDramBlockWindow rotary_sin_window,
|
||||
index_t rotary_dim,
|
||||
index_t thread_end)
|
||||
{
|
||||
using DataType = typename remove_cvref_t<DistributedTensor>::DataType;
|
||||
|
||||
if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
const auto left = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
|
||||
const auto right = type_convert<ComputeDataType>(tile.thread_buf_[idx + 1]);
|
||||
|
||||
const auto cos =
|
||||
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx / 2]);
|
||||
const auto sin =
|
||||
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx / 2]);
|
||||
|
||||
tile.thread_buf_[idx] = type_convert<DataType>(left * cos - right * sin);
|
||||
tile.thread_buf_[idx + 1] = type_convert<DataType>(right * cos + left * sin);
|
||||
});
|
||||
}
|
||||
}
|
||||
else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
|
||||
{
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
const bool is_left = (thread_end <= (rotary_dim / 2));
|
||||
|
||||
move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
|
||||
auto other_tile = load_tile(other_window);
|
||||
|
||||
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
|
||||
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
|
||||
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
|
||||
const auto curr = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
|
||||
const auto other = type_convert<ComputeDataType>(other_tile.thread_buf_[idx]);
|
||||
|
||||
const auto cos =
|
||||
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx]);
|
||||
const auto sin =
|
||||
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx]);
|
||||
|
||||
tile.thread_buf_[idx] =
|
||||
type_convert<DataType>(curr * cos + other * (is_left ? -sin : sin));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,358 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// assume that we have only 1 page-block/tensor view
|
||||
template <typename TensorView>
|
||||
struct TrivialPageBlockNavigator
|
||||
{
|
||||
using DataType = typename TensorView::DataType;
|
||||
using WindowOrigin = multi_index<2>;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr TrivialPageBlockNavigator(const TensorView& tensor_view_)
|
||||
: tensor_view(tensor_view_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename WindowLengths>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths,
|
||||
const WindowOrigin& window_origin) const
|
||||
{
|
||||
return make_tuple(/*block_index=*/0,
|
||||
ck_tile::make_tile_window(tensor_view, window_lengths, window_origin));
|
||||
}
|
||||
|
||||
template <typename WindowLengths, typename TileDistribution>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_tile_window(const WindowLengths& window_lengths,
|
||||
const WindowOrigin& window_origin,
|
||||
const TileDistribution& tile_distribution) const
|
||||
{
|
||||
return make_tuple(
|
||||
/*block_index=*/0,
|
||||
ck_tile::make_tile_window(
|
||||
tensor_view, window_lengths, window_origin, tile_distribution));
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE static index_t
|
||||
move_tile_window(index_t /*block_index*/,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step)
|
||||
{
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
|
||||
return /*block_index=*/0;
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
move_tile_window(index_t /*block_index*/,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step,
|
||||
index_t /*id*/) const
|
||||
{
|
||||
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
prefetch_table_id(index_t /*block_index*/,
|
||||
TileWindow /*tile_window*/,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& /*step*/) const
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
|
||||
to_local_window_origin(const WindowOrigin& global_window_origin)
|
||||
{
|
||||
return global_window_origin;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
|
||||
to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin)
|
||||
{
|
||||
return local_window_origin;
|
||||
}
|
||||
|
||||
private:
|
||||
TensorView tensor_view;
|
||||
};
|
||||
|
||||
// default page-block navigator, assume that tensor view size is same as page-block size or smaller
|
||||
// if tile window on last page-block
|
||||
template <typename DataType_, index_t VirtualDim, typename TensorView>
|
||||
struct PageBlockNavigator
|
||||
{
|
||||
using DataType = DataType_;
|
||||
static_assert(std::is_same_v<DataType, typename TensorView::DataType>);
|
||||
static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window");
|
||||
using WindowOrigin = multi_index<2>;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr PageBlockNavigator(copy_const_t<DataType, void>* physical_blocks_,
|
||||
long_index_t block_stride_,
|
||||
long_index_t fixed_offset_,
|
||||
const int32_t* physical_block_indices_,
|
||||
index_t num_blocks_,
|
||||
index_t page_block_size_,
|
||||
const TensorView& complete_view_,
|
||||
const TensorView& last_view_)
|
||||
: physical_blocks(reinterpret_cast<DataType*>(physical_blocks_)),
|
||||
block_stride(block_stride_),
|
||||
fixed_offset(fixed_offset_),
|
||||
physical_block_indices(physical_block_indices_),
|
||||
num_blocks(num_blocks_),
|
||||
page_block_size(page_block_size_),
|
||||
complete_view(complete_view_),
|
||||
last_view(last_view_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename WindowLengths>
|
||||
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
|
||||
const WindowOrigin& window_origin) const
|
||||
{
|
||||
const index_t block_index = get_block_index(window_origin);
|
||||
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
|
||||
|
||||
auto new_tile_window =
|
||||
ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
|
||||
window_lengths,
|
||||
local_window_origin);
|
||||
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
|
||||
|
||||
return make_tuple(block_index, new_tile_window);
|
||||
}
|
||||
|
||||
template <typename WindowLengths, typename TileDistribution>
|
||||
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
|
||||
const WindowOrigin& window_origin,
|
||||
const TileDistribution& tile_distribution) const
|
||||
{
|
||||
const index_t block_index = get_block_index(window_origin);
|
||||
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
|
||||
|
||||
auto new_tile_window =
|
||||
ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
|
||||
window_lengths,
|
||||
local_window_origin,
|
||||
tile_distribution);
|
||||
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
|
||||
|
||||
return make_tuple(block_index, new_tile_window);
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
move_tile_window(index_t block_index,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
|
||||
{
|
||||
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
|
||||
const WindowOrigin global_window_origin =
|
||||
to_global_window_origin(block_index, tile_window.get_window_origin());
|
||||
const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
|
||||
|
||||
const index_t new_block_index = get_block_index(global_window_origin);
|
||||
/// TODO: only update necessary attributes
|
||||
tile_window.bottom_tensor_view_.desc_ =
|
||||
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
|
||||
tile_window.set_window_origin(local_window_origin);
|
||||
tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
|
||||
|
||||
return new_block_index;
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
move_tile_window(index_t block_index,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step,
|
||||
index_t id) const
|
||||
{
|
||||
ck_tile::move_tile_window(tile_window, step);
|
||||
|
||||
const WindowOrigin global_window_origin =
|
||||
to_global_window_origin(block_index, tile_window.get_window_origin());
|
||||
const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
|
||||
|
||||
const index_t new_block_index = get_block_index(global_window_origin);
|
||||
/// TODO: only update necessary attributes
|
||||
tile_window.bottom_tensor_view_.desc_ =
|
||||
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
|
||||
tile_window.set_window_origin(local_window_origin);
|
||||
if(id >= 0)
|
||||
tile_window.set_bottom_tensor_view_data_ptr(physical_blocks + id * block_stride +
|
||||
fixed_offset);
|
||||
else
|
||||
tile_window.set_bottom_tensor_view_data_ptr(nullptr);
|
||||
|
||||
return new_block_index;
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
prefetch_table_id(index_t block_index,
|
||||
TileWindow& tile_window,
|
||||
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
|
||||
{
|
||||
auto local_tile_window = tile_window; // not affect origin window
|
||||
ck_tile::move_tile_window(local_tile_window, step);
|
||||
|
||||
const WindowOrigin global_window_origin =
|
||||
to_global_window_origin(block_index, local_tile_window.get_window_origin());
|
||||
const index_t new_block_index = get_block_index(global_window_origin);
|
||||
|
||||
if(new_block_index < num_blocks)
|
||||
{
|
||||
return physical_block_indices[new_block_index];
|
||||
}
|
||||
else
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const
|
||||
{
|
||||
return block_index == num_blocks - 1;
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index,
|
||||
const TileWindow& tile_window) const
|
||||
{
|
||||
const index_t origin = tile_window.get_window_origin().at(number<VirtualDim>{});
|
||||
const index_t length = tile_window.get_window_lengths().at(number<VirtualDim>{});
|
||||
return (block_index < num_blocks - 1) && (page_block_size < origin + length);
|
||||
}
|
||||
|
||||
template <typename TileWindow>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const
|
||||
{
|
||||
const multi_index<2> step = [&]() {
|
||||
const index_t origin_diff = (block_index - new_block_index) * page_block_size;
|
||||
if constexpr(VirtualDim == 0)
|
||||
{
|
||||
return make_multi_index(origin_diff, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_multi_index(0, origin_diff);
|
||||
}
|
||||
}();
|
||||
|
||||
/// TODO: only update necessary attributes
|
||||
tile_window.bottom_tensor_view_.desc_ =
|
||||
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
|
||||
tile_window.set_window_origin(tile_window.get_window_origin() + step);
|
||||
tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE WindowOrigin
|
||||
to_local_window_origin(const WindowOrigin& global_window_origin) const
|
||||
{
|
||||
if constexpr(VirtualDim == 0)
|
||||
{
|
||||
const index_t length = global_window_origin.at(number<0>{});
|
||||
const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
|
||||
return make_multi_index(length - page_block_size * num_complete_blocks,
|
||||
global_window_origin.at(number<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t length = global_window_origin.at(number<1>{});
|
||||
const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
|
||||
return make_multi_index(global_window_origin.at(number<0>{}),
|
||||
length - page_block_size * num_complete_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE WindowOrigin
|
||||
to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const
|
||||
{
|
||||
if constexpr(VirtualDim == 0)
|
||||
{
|
||||
return make_multi_index(block_index * page_block_size +
|
||||
local_window_origin.at(number<0>{}),
|
||||
local_window_origin.at(number<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_multi_index(local_window_origin.at(number<0>{}),
|
||||
block_index * page_block_size +
|
||||
local_window_origin.at(number<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
CK_TILE_HOST_DEVICE
|
||||
DataType* get_block_ptr(index_t block_index) const
|
||||
{
|
||||
if(block_index < num_blocks)
|
||||
{
|
||||
return physical_blocks + physical_block_indices[block_index] * block_stride +
|
||||
fixed_offset;
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const
|
||||
{
|
||||
return integer_divide_floor(global_window_origin.at(number<VirtualDim>{}), page_block_size);
|
||||
}
|
||||
|
||||
DataType* physical_blocks;
|
||||
long_index_t block_stride;
|
||||
long_index_t fixed_offset;
|
||||
|
||||
const int32_t* physical_block_indices;
|
||||
index_t num_blocks;
|
||||
index_t page_block_size;
|
||||
|
||||
TensorView complete_view;
|
||||
TensorView last_view;
|
||||
};
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView& tensor_view)
|
||||
{
|
||||
return TrivialPageBlockNavigator<TensorView>(tensor_view);
|
||||
}
|
||||
|
||||
template <typename DataType, index_t VirtualDim, typename TensorView>
|
||||
CK_TILE_HOST_DEVICE auto make_page_block_navigator(copy_const_t<DataType, void>* physical_blocks,
|
||||
long_index_t block_stride,
|
||||
long_index_t fixed_offset,
|
||||
const int32_t* physical_block_indices,
|
||||
index_t num_blocks,
|
||||
index_t page_block_size,
|
||||
const TensorView& complete_view,
|
||||
const TensorView& last_view)
|
||||
{
|
||||
return PageBlockNavigator<DataType, VirtualDim, TensorView>(physical_blocks,
|
||||
block_stride,
|
||||
fixed_offset,
|
||||
physical_block_indices,
|
||||
num_blocks,
|
||||
page_block_size,
|
||||
complete_view,
|
||||
last_view);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
302
include/ck_tile/ops/unified_attention/block/variants.hpp
Normal file
302
include/ck_tile/ops/unified_attention/block/variants.hpp
Normal file
@@ -0,0 +1,302 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include <ck_tile/core/numeric/math.hpp>
|
||||
#include <ck_tile/core/numeric/type_convert.hpp>
|
||||
|
||||
#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 0
|
||||
#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 1
|
||||
|
||||
#ifndef CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT
|
||||
#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM
|
||||
#define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
namespace internal {
|
||||
/// Computes softmax_scale * logits * rcp(1 + |logits * logits_soft_cap_rcp|).
///
/// When building for gfx90a / gfx94 with the soft-sign variant selected and
/// CK_TILE_ATTENTION_USE_SOFTSIGN_ASM enabled, the same sequence is emitted as inline assembly;
/// otherwise the portable expression below is used.
__device__ inline float
exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp)
{
#if(defined(__gfx90a__) || defined(__gfx94__)) && \
    (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
     CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
    /// NOTICE: Make sure softmax_scale is stored in SGPR
    float result, numerator, denominator;
    asm volatile(
        "v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n"
        "v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n"
        "v_rcp_f32_e32 %[denominator], %[denominator]\n"
        "v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n"
        "v_mul_f32_e32 %[result], %[numerator], %[denominator]"
        : [numerator] "=&v"(numerator), [denominator] "=&v"(denominator), [result] "=v"(result)
        : [softmax_scale] "s"(softmax_scale),
          [logits] "v"(logits),
          [logits_soft_cap_rcp] "v"(logits_soft_cap_rcp));
    return result;
#else
    // Same multiplication order as the assembly path: (scale * logits) * rcp(denominator).
    const float scaled_logits   = softmax_scale * logits;
    const float denominator_rcp = rcp<float>(1.f + abs(logits * logits_soft_cap_rcp));
    return scaled_logits * denominator_rcp;
#endif
}
|
||||
} // namespace internal
|
||||
|
||||
template <typename ImplMask>
|
||||
struct StandardAttentionParams
|
||||
{
|
||||
__device__ __host__ StandardAttentionParams(const ImplMask& impl_mask_, float sm_scale_)
|
||||
: impl_mask(impl_mask_), sm_scale(sm_scale_)
|
||||
{
|
||||
}
|
||||
|
||||
const ImplMask& impl_mask;
|
||||
float sm_scale;
|
||||
};
|
||||
|
||||
template <typename ImplMask, bool UseExp2 = false>
|
||||
struct LogitsSoftCapParams
|
||||
{
|
||||
__device__
|
||||
LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_)
|
||||
: impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_)
|
||||
{
|
||||
if(0.f < logits_soft_cap)
|
||||
{
|
||||
logits_soft_cap_rcp = __builtin_amdgcn_rcpf(logits_soft_cap);
|
||||
}
|
||||
else
|
||||
{
|
||||
logits_soft_cap_rcp = 0.f;
|
||||
}
|
||||
|
||||
// move computation here to prevent compiler from generating inefficient instruction
|
||||
// sequence
|
||||
if constexpr(UseExp2)
|
||||
{
|
||||
logits_soft_cap = log2e_v<float> * logits_soft_cap;
|
||||
logits_soft_cap_rcp = sm_scale * log2e_rcp_v<float> * logits_soft_cap_rcp;
|
||||
}
|
||||
}
|
||||
|
||||
__host__
|
||||
LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_)
|
||||
: impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_)
|
||||
{
|
||||
if(0.f < logits_soft_cap)
|
||||
{
|
||||
logits_soft_cap_rcp = 1.f / logits_soft_cap;
|
||||
}
|
||||
else
|
||||
{
|
||||
logits_soft_cap_rcp = 0.f;
|
||||
}
|
||||
|
||||
// move computation here to prevent compiler from generating inefficient instruction
|
||||
// sequence
|
||||
if constexpr(UseExp2)
|
||||
{
|
||||
logits_soft_cap = log2e_v<float> * logits_soft_cap;
|
||||
logits_soft_cap_rcp = sm_scale * log2e_rcp_v<float> * logits_soft_cap_rcp;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __host__ LogitsSoftCapParams(const ImplMask& impl_mask_,
|
||||
float sm_scale_,
|
||||
float logits_soft_cap_,
|
||||
float logits_soft_cap_rcp_)
|
||||
: impl_mask(impl_mask_),
|
||||
sm_scale(sm_scale_),
|
||||
logits_soft_cap(logits_soft_cap_),
|
||||
logits_soft_cap_rcp(logits_soft_cap_rcp_)
|
||||
{
|
||||
// move computation here to prevent compiler from generating inefficient instruction
|
||||
// sequence
|
||||
if constexpr(UseExp2)
|
||||
{
|
||||
logits_soft_cap = log2e_v<float> * logits_soft_cap;
|
||||
logits_soft_cap_rcp = sm_scale * log2e_rcp_v<float> * logits_soft_cap_rcp;
|
||||
}
|
||||
}
|
||||
|
||||
const ImplMask& impl_mask;
|
||||
float sm_scale;
|
||||
float logits_soft_cap;
|
||||
float logits_soft_cap_rcp;
|
||||
};
|
||||
|
||||
struct StandardAttention
|
||||
{
|
||||
__device__ __host__ StandardAttention() = default;
|
||||
|
||||
template <typename Params, typename T>
|
||||
__device__ __forceinline__ T QueryTransform(const Params& params, T q) const
|
||||
{
|
||||
return type_convert<float>(q) * params.sm_scale;
|
||||
}
|
||||
|
||||
/// NOTICE: For better performance, we simpliy transform thread buffer without calculating
|
||||
/// qo_idx/kv_idx.
|
||||
template <typename Params, typename T>
|
||||
__device__ __forceinline__ T LogitsTransform([[maybe_unused]] const Params& params,
|
||||
T logits,
|
||||
[[maybe_unused]] uint32_t batch_idx,
|
||||
/*uint32_t qo_idx, uint32_t kv_idx,*/
|
||||
[[maybe_unused]] uint32_t qo_head_idx,
|
||||
[[maybe_unused]] uint32_t kv_head_idx) const
|
||||
{
|
||||
return logits;
|
||||
}
|
||||
|
||||
template <typename Params>
|
||||
__device__ __forceinline__ bool LogitsMask(const Params& params,
|
||||
[[maybe_unused]] uint32_t batch_idx,
|
||||
uint32_t qo_idx,
|
||||
uint32_t kv_idx,
|
||||
[[maybe_unused]] uint32_t qo_head_idx,
|
||||
[[maybe_unused]] uint32_t kv_head_idx) const
|
||||
{
|
||||
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool UseExp2 = false>
|
||||
struct LogitsSoftCap
|
||||
{
|
||||
__device__ __host__ LogitsSoftCap() = default;
|
||||
|
||||
template <typename Params, typename T>
|
||||
__device__ __forceinline__ T QueryTransform(const Params& params, T q) const
|
||||
{
|
||||
if constexpr(UseExp2)
|
||||
{
|
||||
return q;
|
||||
}
|
||||
else
|
||||
{
|
||||
return type_convert<float>(q) * params.sm_scale;
|
||||
}
|
||||
}
|
||||
|
||||
/// NOTICE: For better performance, we simpliy transform thread buffer without calculating
|
||||
/// qo_idx/kv_idx.
|
||||
template <typename Params, typename T>
|
||||
__device__ __forceinline__ T LogitsTransform(const Params& params,
|
||||
T logits,
|
||||
[[maybe_unused]] uint32_t batch_idx,
|
||||
/*uint32_t qo_idx, uint32_t kv_idx,*/
|
||||
[[maybe_unused]] uint32_t qo_head_idx,
|
||||
[[maybe_unused]] uint32_t kv_head_idx) const
|
||||
{
|
||||
if constexpr(UseExp2)
|
||||
{
|
||||
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
|
||||
return params.logits_soft_cap *
|
||||
tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
|
||||
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
|
||||
return internal::exp2_soft_sign_impl(
|
||||
params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
|
||||
return params.logits_soft_cap *
|
||||
tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
|
||||
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
|
||||
return type_convert<float>(logits) *
|
||||
rcp<float>(1.f + abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Params>
|
||||
__device__ __forceinline__ bool LogitsMask(const Params& params,
|
||||
[[maybe_unused]] uint32_t batch_idx,
|
||||
uint32_t qo_idx,
|
||||
uint32_t kv_idx,
|
||||
[[maybe_unused]] uint32_t qo_head_idx,
|
||||
[[maybe_unused]] uint32_t kv_head_idx) const
|
||||
{
|
||||
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
|
||||
}
|
||||
};
|
||||
|
||||
// Single-bit flags (1, 2, 4, 8). ComposedAttention tests its VARIANT_CODE against
// LOGITS_SOFT_CAP; the other flags are presumably combined into the same code -- confirm at
// the call sites.
constexpr uint32_t CUSTOM_MASK     = 1U << 0;
constexpr uint32_t SLIDING_WINDOW  = 1U << 1;
constexpr uint32_t LOGITS_SOFT_CAP = 1U << 2;
constexpr uint32_t ALIBI           = 1U << 3;
|
||||
|
||||
template <uint32_t VARIANT_CODE, bool UseExp2 = false>
|
||||
struct ComposedAttention
|
||||
{
|
||||
static constexpr bool use_exp2 = UseExp2;
|
||||
|
||||
static constexpr bool use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0;
|
||||
|
||||
__device__ __host__ ComposedAttention() = default;
|
||||
|
||||
template <typename Params, typename T>
|
||||
__device__ __forceinline__ T QueryTransform(const Params& params, T q) const
|
||||
{
|
||||
if constexpr(use_logits_soft_cap && UseExp2)
|
||||
{
|
||||
return q;
|
||||
}
|
||||
return type_convert<float>(q) * params.sm_scale;
|
||||
}
|
||||
|
||||
/// NOTICE: For better performance, we simpliy transform thread buffer without calculating
|
||||
/// qo_idx/kv_idx.
|
||||
template <typename Params, typename T>
|
||||
__device__ __forceinline__ T LogitsTransform(const Params& params,
|
||||
T logits,
|
||||
[[maybe_unused]] uint32_t batch_idx,
|
||||
/*uint32_t qo_idx, uint32_t kv_idx,*/
|
||||
[[maybe_unused]] uint32_t qo_head_idx,
|
||||
[[maybe_unused]] uint32_t kv_head_idx) const
|
||||
{
|
||||
if constexpr(use_logits_soft_cap)
|
||||
{
|
||||
if constexpr(UseExp2)
|
||||
{
|
||||
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
|
||||
return params.logits_soft_cap *
|
||||
tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
|
||||
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
|
||||
return internal::exp2_soft_sign_impl(
|
||||
params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
|
||||
return params.logits_soft_cap *
|
||||
tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
|
||||
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
|
||||
return type_convert<float>(logits) *
|
||||
rcp<float>(1.f +
|
||||
abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return logits;
|
||||
}
|
||||
|
||||
template <typename Params>
|
||||
__device__ __forceinline__ bool LogitsMask(const Params& params,
|
||||
[[maybe_unused]] uint32_t batch_idx,
|
||||
uint32_t qo_idx,
|
||||
uint32_t kv_idx,
|
||||
[[maybe_unused]] uint32_t qo_head_idx,
|
||||
[[maybe_unused]] uint32_t kv_head_idx) const
|
||||
{
|
||||
return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,450 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename FmhaPipeline_, typename EpiloguePipeline_>
|
||||
struct FmhaFwdV3Kernel
|
||||
{
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
static_assert(kBlockPerCu > 0);
|
||||
|
||||
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
|
||||
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
|
||||
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
|
||||
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
|
||||
|
||||
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
|
||||
// kargs use aggregate initialization, so no constructor is provided
|
||||
// use inheritance to minimize karg size
|
||||
// user need to use MakeKargs() function to create kargs.
|
||||
// The attention is default causal
|
||||
struct UnifiedAttentionCommonKargs
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr; // [num_blks, blk_size, num_kv_heads, head_size]
|
||||
const void* v_ptr; // [num_blks, blk_size, num_kv_heads, head_size]
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
|
||||
ck_tile::index_t num_head_q;
|
||||
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
|
||||
// if this param is larger than 1, indicate MQA/GQA case
|
||||
ck_tile::index_t num_queries_per_kv;
|
||||
// scales
|
||||
float scale_s;
|
||||
float scale;
|
||||
float scale_k;
|
||||
float scale_v;
|
||||
float scale_out;
|
||||
|
||||
ck_tile::index_t total_num_q_blocks;
|
||||
ck_tile::index_t query_stride_0;
|
||||
ck_tile::index_t query_stride_1;
|
||||
ck_tile::index_t stride_k_cache_0;
|
||||
ck_tile::index_t stride_k_cache_1;
|
||||
ck_tile::index_t stride_k_cache_2;
|
||||
ck_tile::index_t stride_k_cache_3;
|
||||
ck_tile::index_t stride_v_cache_0;
|
||||
ck_tile::index_t stride_v_cache_1;
|
||||
ck_tile::index_t stride_v_cache_2;
|
||||
ck_tile::index_t stride_v_cache_3;
|
||||
ck_tile::index_t output_stride_0;
|
||||
ck_tile::index_t output_stride_1;
|
||||
ck_tile::index_t HEAD_SIZE_PADDED;
|
||||
};
|
||||
|
||||
|
||||
struct UnifiedAttentionVarlenKargs
|
||||
{
|
||||
const int32_t* block_tables_ptr;
|
||||
const int32_t* seq_lens_ptr; // seq len in each batch
|
||||
const int32_t* query_start_len_ptr; // [num_seqs+1]
|
||||
|
||||
ck_tile::index_t num_seqs; // number of batches for q
|
||||
ck_tile::index_t BLOCK_SIZE; // Block size for kv cache. to 2's exponent????
|
||||
ck_tile::index_t BLOCK_Q; // Block size for kv cache. to 2's exponent????
|
||||
};
|
||||
|
||||
struct Kargs {
|
||||
UnifiedAttentionCommonKargs unifiedAttentionCommonKargs;
|
||||
UnifiedAttentionVarlenKargs unifiedAttentionVarlenKargs;
|
||||
};
|
||||
|
||||
// using Kargs = FmhaFwdGroupModeKargs;
|
||||
|
||||
// Builds the kernel argument pack from host-side values.
//
// The arguments are forwarded positionally into the two sub-structs of Kargs: everything up to
// and including output_stride_1 goes to the common arguments, the remaining six go to the
// variable-length arguments. The only value that is transformed is scale_s, which is multiplied
// by ck_tile::log2e_v<> before being stored.
//
// NOTE(review): the common initializer list stops at output_stride_1, so a member declared
// after it (HEAD_SIZE_PADDED) receives no explicit value here — confirm this is intended.
CK_TILE_HOST static constexpr Kargs MakeKargs(
    const void* q_ptr,
    const void* k_ptr,
    const void* v_ptr,
    void* o_ptr,
    ck_tile::index_t hdim_q,
    ck_tile::index_t hdim_v,
    ck_tile::index_t num_head_q,
    ck_tile::index_t num_queries_per_kv,
    float scale_s,
    float scale,
    float scale_k,
    float scale_v,
    float scale_out,
    ck_tile::index_t total_num_q_blocks,
    ck_tile::index_t query_stride_0,
    ck_tile::index_t query_stride_1,
    ck_tile::index_t stride_k_cache_0,
    ck_tile::index_t stride_k_cache_1,
    ck_tile::index_t stride_k_cache_2,
    ck_tile::index_t stride_k_cache_3,
    ck_tile::index_t stride_v_cache_0,
    ck_tile::index_t stride_v_cache_1,
    ck_tile::index_t stride_v_cache_2,
    ck_tile::index_t stride_v_cache_3,
    ck_tile::index_t output_stride_0,
    ck_tile::index_t output_stride_1,
    const int32_t* block_tables_ptr,
    const int32_t* seq_lens_ptr,
    const int32_t* query_start_len_ptr,
    ck_tile::index_t num_seqs,
    ck_tile::index_t BLOCK_SIZE,
    ck_tile::index_t BLOCK_Q)
{
    // softmax scale with the log2(e) factor folded in once on the host
    const float scale_s_log2e = static_cast<float>(scale_s * ck_tile::log2e_v<>);

    return Kargs{// common arguments
                 {q_ptr,
                  k_ptr,
                  v_ptr,
                  o_ptr,
                  hdim_q,
                  hdim_v,
                  num_head_q,
                  num_queries_per_kv,
                  scale_s_log2e,
                  scale,
                  scale_k,
                  scale_v,
                  scale_out,
                  total_num_q_blocks,
                  query_stride_0,
                  query_stride_1,
                  stride_k_cache_0,
                  stride_k_cache_1,
                  stride_k_cache_2,
                  stride_k_cache_3,
                  stride_v_cache_0,
                  stride_v_cache_1,
                  stride_v_cache_2,
                  stride_v_cache_3,
                  output_stride_0,
                  output_stride_1},
                 // variable-length batch arguments
                 {block_tables_ptr,
                  seq_lens_ptr,
                  query_start_len_ptr,
                  num_seqs,
                  BLOCK_SIZE,
                  BLOCK_Q}};
}
|
||||
|
||||
// Launch grid for the kernel: one workgroup per (kv head, query block) pair, flattened onto
// the x dimension.
//
// @param num_kv_heads        number of KV heads
// @param total_num_q_blocks  total number of query blocks over all sequences
// @return dim3 with x = num_kv_heads * total_num_q_blocks and y = z = 1
CK_TILE_HOST static constexpr auto GridSize2D(ck_tile::index_t num_kv_heads,
                                              ck_tile::index_t total_num_q_blocks)
{
    // The unused dimensions must be 1, not 0: a grid with a zero extent contains no
    // workgroups and is rejected as an invalid launch configuration.
    return dim3(num_kv_heads * total_num_q_blocks, 1, 1);
}
|
||||
|
||||
// CK_TILE_HOST static constexpr auto GridSize3D(ck_tile::index_t num_kv_heads,
|
||||
// ck_tile::index_t total_num_q_blocks)
|
||||
// {
|
||||
// // TODO: fix 3D grid
|
||||
// return dim2(num_kv_heads, total_num_q_blocks);
|
||||
// }
|
||||
|
||||
// Binary search to find the sequence index for a given target index
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t
|
||||
find_seq_idx(const int32_t* query_start_len_ptr,
|
||||
ck_tile::index_t target_idx,
|
||||
ck_tile::index_t num_seqs,
|
||||
ck_tile::index_t BLOCK_Q,
|
||||
bool use_q_block_mode)
|
||||
{
|
||||
ck_tile::index_t left = 0;
|
||||
ck_tile::index_t right = num_seqs;
|
||||
|
||||
while (left < right)
|
||||
{
|
||||
ck_tile::index_t mid = (left + right) / 2;
|
||||
ck_tile::index_t val = query_start_len_ptr[mid];
|
||||
ck_tile::index_t mid_val = use_q_block_mode ? (val / BLOCK_Q + mid) : val;
|
||||
|
||||
if (mid_val <= target_idx)
|
||||
{
|
||||
left = mid + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
right = mid;
|
||||
}
|
||||
}
|
||||
|
||||
return left - 1;
|
||||
}
|
||||
|
||||
// Remaps a hardware workgroup id so that ids that are NUM_XCDS apart become consecutive.
//
// Workgroup `pid` is treated as running on XCD `pid % NUM_XCDS`; the returned id places all
// workgroups of XCD 0 first, then XCD 1, and so on. When the grid size is not a multiple of
// NUM_XCDS, the first `tall_xcds` XCDs own one more id than the rest.
//
// The remap is only a permutation of [0, grid size) if it is computed against the launched
// grid size. GridSize2D launches num_kv_heads * total_num_q_blocks workgroups, so the number
// of KV heads (num_head_q / num_queries_per_kv) is used here rather than num_head_q; with
// num_head_q the remapped id could exceed the launched range whenever num_queries_per_kv > 1.
//
// @param pid    hardware workgroup id (blockIdx.x)
// @param kargs  kernel arguments; reads total_num_q_blocks, num_head_q, num_queries_per_kv
// @return remapped workgroup id
CK_TILE_DEVICE static constexpr auto
RemapTileIndices(const ck_tile::index_t pid, const Kargs& kargs)
{
    using namespace ck_tile;

    constexpr index_t NUM_XCDS = 8;

    const auto& common = kargs.unifiedAttentionCommonKargs;

    const index_t num_kv_heads = common.num_head_q / common.num_queries_per_kv;
    const index_t GRID_MN      = common.total_num_q_blocks * num_kv_heads;

    // Number of pids per XCD in the new arrangement
    const index_t pids_per_xcd = (GRID_MN + NUM_XCDS - 1) / NUM_XCDS;

    // When GRID_MN is not a multiple of NUM_XCDS, some XCDs get pids_per_xcd pids and the
    // others pids_per_xcd - 1. tall_xcds is the number of XCDs with pids_per_xcd pids.
    const index_t remainder = GRID_MN % NUM_XCDS;
    const index_t tall_xcds = remainder == 0 ? NUM_XCDS : remainder;

    // Current XCD and local pid within the XCD
    const index_t xcd       = pid % NUM_XCDS;
    const index_t local_pid = pid / NUM_XCDS;

    if(xcd < tall_xcds)
    {
        return xcd * pids_per_xcd + local_pid;
    }

    // skip all tall XCDs, then the short XCDs in front of this one
    return tall_xcds * pids_per_xcd + (xcd - tall_xcds) * (pids_per_xcd - 1) + local_pid;
}
|
||||
|
||||
// Splits a flattened workgroup id into its two tile coordinates.
//
// @return tuple (query block index, head index), in that order:
//         element 0 = pid % total_num_q_blocks, element 1 = pid / total_num_q_blocks
CK_TILE_DEVICE static constexpr auto GetTileIndex(const ck_tile::index_t pid, const Kargs& kargs)
{
    const ck_tile::index_t num_q_blocks = kargs.unifiedAttentionCommonKargs.total_num_q_blocks;

    const ck_tile::index_t q_block_idx = pid % num_q_blocks; // fastest-varying coordinate
    const ck_tile::index_t head_idx    = pid / num_q_blocks;

    return ck_tile::make_tuple(q_block_idx, head_idx);
}
|
||||
|
||||
// Size of the shared-memory buffer the kernel declares: the larger of what the attention
// pipeline and the epilogue report (the same buffer is handed to the pipeline).
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    const auto pipeline_size = FmhaPipeline::GetSmemSize();
    const auto epilogue_size = EpiloguePipeline::GetSmemSize();

    return ck_tile::max(pipeline_size, epilogue_size);
}
|
||||
|
||||
|
||||
|
||||
// Kernel entry point: one workgroup processes one (kv head, query block) tile.
//
// Steps performed by the part of this function that is consistent with Kargs:
//   1. remap blockIdx.x (RemapTileIndices) and split it into (q block, kv head);
//   2. locate the owning sequence with find_seq_idx and derive the local q-block index;
//   3. exit early when the q block starts past the end of the sequence;
//   4. offset the Q and O base pointers to the first token of the sequence and the first
//      query head of the kv-head group.
//
// Fixed relative to the previous revision:
//   - GetTileIndex returns (q block, head); the binding order was swapped.
//   - BLOCK_Q lives in unifiedAttentionVarlenKargs, not in the common arguments.
//   - q_block_start_idx must use the same token->block mapping as find_seq_idx
//     (start / BLOCK_Q + seq_idx); it was the raw token start.
//   - o_ptr referenced undefined names; it now mirrors the q_ptr offset with the output strides.
//   - the q_dram view had an incomplete expression (syntax error).
//
// FIXME: everything from the K/V views onward is still the unported FMHA-forward code. It
// references names that do not exist here (k_ptr, v_ptr, i_m0, i_n1, i_nhead, batch_offset_lse,
// kargs.seqlen_q/seqlen_k/hdim_*/stride_*/lse_ptr/window_size_*/mask_type/scale_s, ...) and
// does not yet read K/V through block_tables_ptr. It is left unchanged on purpose.
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
    using namespace ck_tile;

    const auto& common = kargs.unifiedAttentionCommonKargs;
    const auto& varlen = kargs.unifiedAttentionVarlenKargs;

    // allocate LDS
    __shared__ char smem_ptr[GetSmemSize()];

    const index_t pid = RemapTileIndices(blockIdx.x, kargs);

    // divide problem; GetTileIndex returns (query block index, head index)
    const auto [q_block_global_idx, kv_head_idx] = GetTileIndex(pid, kargs);

    const index_t seq_idx = find_seq_idx(varlen.query_start_len_ptr,
                                         q_block_global_idx,
                                         varlen.num_seqs,
                                         varlen.BLOCK_Q,
                                         true); // which batch

    const index_t cur_batch_in_all_start_index =
        amd_wave_read_first_lane(varlen.query_start_len_ptr[seq_idx]);
    const index_t cur_batch_in_all_stop_index =
        amd_wave_read_first_lane(varlen.query_start_len_ptr[seq_idx + 1]);

    const index_t cur_batch_query_len =
        cur_batch_in_all_stop_index - cur_batch_in_all_start_index;

    // global index of the first q block of this sequence (same key as find_seq_idx uses)
    const index_t q_block_start_idx =
        amd_wave_read_first_lane(cur_batch_in_all_start_index / varlen.BLOCK_Q + seq_idx);

    const index_t q_block_local_idx =
        amd_wave_read_first_lane(q_block_global_idx - q_block_start_idx);

    // TODO check if we get the block size info from pipeline
    if(q_block_local_idx * varlen.BLOCK_Q >= cur_batch_query_len)
    {
        return;
    }

    // token offset of this q block inside its sequence
    // NOTE(review): presumably the tile-window origin that replaces i_m0 below — confirm
    [[maybe_unused]] const index_t query_pos = q_block_local_idx * varlen.BLOCK_Q;

    // for simplicity, batch stride we just modify the pointer
    const QDataType* q_ptr =
        reinterpret_cast<const QDataType*>(common.q_ptr) +
        static_cast<long_index_t>(kv_head_idx) * common.num_queries_per_kv *
            common.query_stride_1 +
        static_cast<long_index_t>(cur_batch_in_all_start_index) * common.query_stride_0;

    // same offset scheme as q_ptr, using the output strides
    // NOTE(review): assumes the common arguments expose the output pointer as `o_ptr`
    ODataType* o_ptr =
        reinterpret_cast<ODataType*>(common.o_ptr) +
        static_cast<long_index_t>(kv_head_idx) * common.num_queries_per_kv *
            common.output_stride_1 +
        static_cast<long_index_t>(cur_batch_in_all_start_index) * common.output_stride_0;

    // Q/K/V DRAM and DRAM window
    const auto q_dram = [&]() {
        // NOTE(review): (tokens of this sequence, hdim_q) view of a single query head; the
        // num_queries_per_kv heads of the group are not folded into the rows yet. Assumes
        // the common arguments expose `hdim_q`.
        const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            q_ptr,
            make_tuple(cur_batch_query_len, common.hdim_q),
            make_tuple(common.query_stride_0, 1),
            number<FmhaPipeline::kAlignmentQ>{},
            number<1>{});

        return pad_tensor_view(
            q_dram_naive,
            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
            sequence<kPadSeqLenQ, kPadHeadDimQ>{});
    }();
    const auto k_dram = [&]() {
        const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            k_ptr,
            make_tuple(kargs.seqlen_k, kargs.hdim_q),
            make_tuple(kargs.stride_k, 1),
            number<FmhaPipeline::kAlignmentK>{},
            number<1>{});

        return pad_tensor_view(
            k_dram_naive,
            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
            sequence<kPadSeqLenK, kPadHeadDimQ>{});
    }();
    const auto v_dram = [&]() {
        const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            v_ptr,
            make_tuple(kargs.seqlen_k, kargs.hdim_v),
            make_tuple(kargs.stride_v, 1),
            number<FmhaPipeline::kAlignmentV>{},
            number<1>{});

        return pad_tensor_view(
            v_dram_naive,
            make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
            sequence<kPadSeqLenK, kPadHeadDimV>{});
    }();

    auto q_dram_window = make_tile_window(
        q_dram,
        make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
        {i_m0, 0});

    auto k_dram_window = make_tile_window(
        k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});

    auto v_dram_window =
        make_tile_window(v_dram,
                         make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
                         {0, i_n1});

    // lse
    auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
        constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
        if constexpr(kStoreLSE)
        {
            LSEDataType* lse_ptr =
                reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
                static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;

            const auto lse_dram = [&]() {
                const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    lse_ptr,
                    make_tuple(kargs.seqlen_q),
                    make_tuple(1),
                    number<1>{},
                    number<1>{});

                return pad_tensor_view(
                    lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
            }();

            return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
        }
        else
        {
            return make_null_tile_window(lse_dram_window_lengths);
        }
    }();

    FmhaMask mask = [&]() {
        if constexpr(kHasMask)
            return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
                kargs.window_size_left,
                kargs.window_size_right,
                kargs.seqlen_q,
                kargs.seqlen_k,
                kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
        else
            return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
    }();

    auto o_acc_tile = [&]() {
        return FmhaPipeline{}(q_dram_window,
                              k_dram_window,
                              v_dram_window,
                              lse_dram_window,
                              mask,
                              kargs.scale_s,
                              smem_ptr);
    }();

    // O DRAM and O DRAM window
    auto o_dram = [&]() {
        const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            o_ptr,
            make_tuple(kargs.seqlen_q, kargs.hdim_v),
            make_tuple(kargs.stride_o, 1),
            number<FmhaPipeline::kAlignmentO>{},
            number<1>{});

        return pad_tensor_view(
            o_dram_naive,
            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
            sequence<kPadSeqLenQ, kPadHeadDimV>{});
    }();

    auto o_dram_window =
        make_tile_window(o_dram,
                         make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
                         {i_m0, i_n1});

    EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,603 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct UnifiedAttentionPipelineDefaultPolicy
|
||||
{
|
||||
// Number of warps that form one warp group.
static constexpr ck_tile::index_t NumWarpPerGroup = 4;
|
||||
// Threads per warp group: NumWarpPerGroup times the warp size reported by ck_tile.
static constexpr ck_tile::index_t NumThreadPerWarpGroup =
|
||||
NumWarpPerGroup * ck_tile::get_warp_size();
|
||||
|
||||
// TODO: GetAlignment*() currently does not consider whether padding is needed,
|
||||
// so the pipeline still needs to check the padding requirement
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
|
||||
{
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto GetAlignmentK()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
#if defined(__gfx950__)
|
||||
constexpr index_t MaxReadSizeInBytes = 16;
|
||||
#else
|
||||
constexpr index_t MaxReadSizeInBytes = 4;
|
||||
#endif
|
||||
return MaxReadSizeInBytes / sizeof(KDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto GetAlignmentV()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
#if defined(__gfx950__)
|
||||
constexpr index_t MaxReadSizeInBytes = 16;
|
||||
#else
|
||||
constexpr index_t MaxReadSizeInBytes = 4;
|
||||
#endif
|
||||
return MaxReadSizeInBytes / sizeof(VDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// TODO: this is for 3d layout
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
return 16 / sizeof(KDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemVPackK()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// TODO: this is for 3d layout
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
return 16 / sizeof(VDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeKDramTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK0;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
|
||||
constexpr index_t N0 = NumIssues;
|
||||
constexpr index_t N1 = LaneGroups;
|
||||
constexpr index_t N2 = NumWarps;
|
||||
constexpr index_t K0 = LanesPerK;
|
||||
constexpr index_t K1 = KVector;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<2>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kN1;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
|
||||
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
|
||||
constexpr index_t N0 = NumIssues;
|
||||
constexpr index_t N1 = LaneGroups;
|
||||
constexpr index_t N2 = NumWarps;
|
||||
constexpr index_t K0 = LanesPerK;
|
||||
constexpr index_t K1 = KVector;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<2>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeQRegTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
|
||||
return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeKRegTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
|
||||
return make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakePRegTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
|
||||
|
||||
return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeVRegTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::UnifiedAttentionShape::Gemm1BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK1;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto v_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto v_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding(
|
||||
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
// compute the endcoding before transpose
|
||||
constexpr auto v_block_dstr =
|
||||
make_static_tile_distribution(typename InputTileDistributionTraits<
|
||||
decltype(v_block_dstr_encode),
|
||||
typename Problem::VDataType>::TransposedDstrEncode{});
|
||||
|
||||
return v_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<sequence<Problem::UnifiedAttentionShape::kM0,
|
||||
Problem::UnifiedAttentionShape::kN0,
|
||||
Problem::UnifiedAttentionShape::kK0>,
|
||||
typename Problem::UnifiedAttentionShape::Gemm0BlockWarps,
|
||||
typename Problem::UnifiedAttentionShape::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
/// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use
|
||||
/// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here
|
||||
return WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::KDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::SaccDataType, float>)
|
||||
{
|
||||
/// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use
|
||||
/// WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution here
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>{};
|
||||
}
|
||||
}();
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
typename Problem::UnifiedAttentionShape::Gemm0BlockWarps,
|
||||
decltype(warp_gemm),
|
||||
GemmLoopOrder::MNK>;
|
||||
|
||||
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto GetPVBlockGemm()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<sequence<Problem::UnifiedAttentionShape::kM0,
|
||||
Problem::UnifiedAttentionShape::kN1,
|
||||
Problem::UnifiedAttentionShape::kK1>,
|
||||
typename Problem::UnifiedAttentionShape::Gemm1BlockWarps,
|
||||
typename Problem::UnifiedAttentionShape::Gemm1WarpTile>>;
|
||||
/// NOTICE: in order to use load_tile_transpose() later for V tiles, we have to pass
|
||||
/// WGAttrNumAccessEnum::Double instead of WGAttrNumAccessEnum::Single
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::UnifiedAttentionShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
WGAttrNumAccessEnum::Double>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
typename Problem::UnifiedAttentionShape::Gemm1BlockWarps,
|
||||
WarpGemm,
|
||||
GemmLoopOrder::MNK>;
|
||||
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
// Padding, in bytes, used by the K LDS descriptors; they divide it by sizeof(KDataType) and add
// it after each warp's WarpSize * KVector elements.
static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords
|
||||
// Padding, in bytes, used the same way by the V LDS descriptors (divided by sizeof(VDataType)).
static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords
|
||||
|
||||
template <typename Problem, ck_tile::index_t IBuf = 0>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
MakeKLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// K is always k-major, we use async-copy to load into LDS
|
||||
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK0;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
[[maybe_unused]] constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad =
|
||||
kKLdsPadInBytes /
|
||||
sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps.
|
||||
// Optimize this for lds_read speed
|
||||
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK =
|
||||
kKPerBlock / KVector; // how many lane (within a wave) to load K
|
||||
constexpr index_t LaneGroups =
|
||||
WarpSize /
|
||||
LanesPerK; // how many groups (within a wave), they may load different N, but same K
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
|
||||
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
|
||||
make_tuple(number<NumIssues>{}, // n0
|
||||
number<LaneGroups>{}, // n1
|
||||
number<NumWarps>{}, // n2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<kKPerBlock>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<KVector>{},
|
||||
number<1>{}),
|
||||
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
|
||||
number<KVector>{},
|
||||
number<1>{});
|
||||
|
||||
// TODO this layout is hard coded, and will be used in async copy buffer view load
|
||||
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
|
||||
constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
|
||||
k_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<NumIssues>{}),
|
||||
make_pass_through_transform(number<NumWarps>{}),
|
||||
make_merge_transform(make_tuple(
|
||||
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
return k_lds_block_desc_issues_warps_lanes;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// K is always k-major, we use async-copy to load into LDS
|
||||
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK0;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad =
|
||||
kKLdsPadInBytes /
|
||||
sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps
|
||||
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
|
||||
constexpr auto k_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
|
||||
number<NumWarps>{}, // n2
|
||||
number<LaneGroups>{}, // n1
|
||||
number<kKPerBlock / KPack>{}, // k0
|
||||
number<KPack>{}), // k1
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<kKPerBlock>{},
|
||||
number<KPack>{},
|
||||
number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
|
||||
k_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return k_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
|
||||
{
|
||||
// this function assume K/V can share smem
|
||||
constexpr index_t SingleKSize = [&]() {
|
||||
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK1;
|
||||
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad = KPack;
|
||||
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector;
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK;
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
|
||||
return NumIssues * NumWarps * (WarpSize * KVector + kPad);
|
||||
}();
|
||||
|
||||
constexpr index_t SingleVSize = [&]() {
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t Banks = 32; // TODO: need change based on arch
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
constexpr index_t NPerRow = PixelsPerRow / kKPack;
|
||||
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kK1;
|
||||
static_assert(kNPerBlock % NPerRow == 0);
|
||||
static_assert(kKPerBlock % kKPack == 0);
|
||||
|
||||
return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
|
||||
}();
|
||||
|
||||
return max(SingleKSize, SingleVSize);
|
||||
}
|
||||
|
||||
template <typename Problem, ck_tile::index_t IBuf = 0>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
MakeVLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
/// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension
|
||||
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kN1;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
[[maybe_unused]] constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
|
||||
constexpr index_t kPad =
|
||||
kVLdsPadInBytes /
|
||||
sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps.
|
||||
// Optimize this for lds_read speed
|
||||
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK =
|
||||
kKPerBlock / KVector; // how many lane (within a wave) to load K
|
||||
constexpr index_t LaneGroups =
|
||||
WarpSize /
|
||||
LanesPerK; // how many groups (within a wave), they may load different N, but same K
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
|
||||
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
|
||||
make_tuple(number<NumIssues>{}, // n0
|
||||
number<LaneGroups>{}, // n1
|
||||
number<NumWarps>{}, // n2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<kKPerBlock>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<KVector>{},
|
||||
number<1>{}),
|
||||
number<(IBuf + 2) * GetSingleSmemElementSpaceSize<Problem>()>{},
|
||||
number<KVector>{},
|
||||
number<1>{});
|
||||
|
||||
// TODO this layout is hard coded, and will be used in async copy buffer view load
|
||||
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
|
||||
constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
|
||||
v_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<NumIssues>{}),
|
||||
make_pass_through_transform(number<NumWarps>{}),
|
||||
make_merge_transform(make_tuple(
|
||||
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
return v_lds_block_desc_issues_warps_lanes;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeVLdsLoadBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
/// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension
|
||||
constexpr index_t kNPerBlock = Problem::UnifiedAttentionShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::UnifiedAttentionShape::kN1;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::UnifiedAttentionShape::NumWarps;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad =
|
||||
kVLdsPadInBytes /
|
||||
sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps
|
||||
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
|
||||
constexpr auto v_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
|
||||
number<NumWarps>{}, // n2
|
||||
number<LaneGroups>{}, // n1
|
||||
number<kKPerBlock / KPack>{}, // k0
|
||||
number<KPack>{}), // k1
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<kKPerBlock>{},
|
||||
number<KPack>{},
|
||||
number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
|
||||
v_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return v_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
static_assert(MakeKLdsLoadBlockDescriptor<Problem>().get_element_space_size() ==
|
||||
MakeKLdsStoreBlockDescriptor<Problem>().get_element_space_size());
|
||||
constexpr index_t k_element_space_size =
|
||||
MakeKLdsLoadBlockDescriptor<Problem>().get_element_space_size();
|
||||
|
||||
static_assert(MakeVLdsLoadBlockDescriptor<Problem>().get_element_space_size() ==
|
||||
MakeVLdsStoreBlockDescriptor<Problem>().get_element_space_size());
|
||||
constexpr index_t v_element_space_size =
|
||||
MakeVLdsLoadBlockDescriptor<Problem>().get_element_space_size();
|
||||
|
||||
static_assert(ck_tile::max(k_element_space_size, v_element_space_size) <=
|
||||
GetSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
/// TODO: override GetSingleSmemElementSpaceSize() to align with MakeKLdsBlockDescriptor() &
|
||||
/// MakeVLdsBlockDescriptor()
|
||||
static_assert(std::is_same_v<typename Problem::KDataType, typename Problem::VDataType>);
|
||||
constexpr index_t kv_element_space_size_in_bytes =
|
||||
GetSingleSmemElementSpaceSize<Problem>() * sizeof(typename Problem::KDataType);
|
||||
|
||||
return kv_element_space_size_in_bytes;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return 4 * GetSmemSizeKV<Problem>();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,42 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Pipeline variants; this enum is used for codegen pattern matching.
// The numeric values are spelled out so they stay stable if enumerators are reordered.
enum class BlockFmhaPipelineEnum
{
    QRKSVS              = 0,
    QRKSVS_ASYNC        = 1,
    QSKSVS              = 2,
    QRKSVS_ASYNC_TRLOAD = 3,
};
|
||||
|
||||
template <BlockFmhaPipelineEnum>
|
||||
struct BlockFmhaPipelineEnumToStr;
|
||||
|
||||
template <>
|
||||
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS>
|
||||
{
|
||||
static constexpr const char* name = "qr";
|
||||
};
|
||||
template <>
|
||||
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC>
|
||||
{
|
||||
static constexpr const char* name = "qr_async";
|
||||
};
|
||||
template <>
|
||||
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QSKSVS>
|
||||
{
|
||||
static constexpr const char* name = "qs";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD>
|
||||
{
|
||||
static constexpr const char* name = "qr_async_trload";
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/unified_attention/block/block_rotary_embedding.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename BiasDataType_,
|
||||
typename RandValOutputDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
typename ODataType_,
|
||||
typename UnifiedAttentionShape_,
|
||||
typename Traits_>
|
||||
struct UnifiedAttentionPipelineProblem
|
||||
{
|
||||
// TODO kM0 and KN1??
|
||||
using QDataType = remove_cvref_t<QDataType_>;
|
||||
using KDataType = remove_cvref_t<KDataType_>;
|
||||
using VDataType = remove_cvref_t<VDataType_>;
|
||||
// first gemm accumulation dtype
|
||||
using SaccDataType = remove_cvref_t<SaccDataType_>;
|
||||
// Softmax dtype
|
||||
using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
|
||||
using BiasDataType = remove_cvref_t<BiasDataType_>;
|
||||
using RandValOutputDataType = remove_cvref_t<RandValOutputDataType_>;
|
||||
// data type for A matrix of second gemm
|
||||
using PDataType = remove_cvref_t<PDataType_>;
|
||||
// data type for second gemm accumulation
|
||||
using OaccDataType = remove_cvref_t<OaccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using UnifiedAttentionShape = remove_cvref_t<UnifiedAttentionShape_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kNumGemm0Warps = UnifiedAttentionShape::NumGemm0Warps;
|
||||
static constexpr index_t kNumGemm1Warps = UnifiedAttentionShape::NumGemm1Warps;
|
||||
static constexpr index_t kBlockSize = UnifiedAttentionShape::NumWarps * get_warp_size();
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
|
||||
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
}
|
||||
Reference in New Issue
Block a user