Completely remove the dependency to include/ck_tile/ops/fmha/ops headers

This commit is contained in:
Qianfeng Zhang
2025-06-22 11:29:03 +00:00
parent 4fa6474254
commit 463a19859a
7 changed files with 240 additions and 147 deletions

View File

@@ -10,7 +10,6 @@
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/host/stream_config.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include <ck_tile/ops/fmha.hpp>
#include "hstu_attention_bool_switch.hpp"
#include "hstu_attention_fwd_type_config.hpp"
@@ -32,8 +31,8 @@ template <typename InOutDataType,
ck_tile::index_t MaxK>
struct batched_forward_causal_local_bias_dropout_dispatch
{
using HstuAttentionShape = typename HstuAttentionFwdShape<MaxK>::Type;
using HstuMask = typename ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>::Type;
using HstuAttentionTileSetting = typename HstuAttentionFwdTileSetting<MaxK>::Type;
using HstuMask = typename ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>::Type;
template <typename HstuTraits>
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
@@ -45,16 +44,16 @@ struct batched_forward_causal_local_bias_dropout_dispatch
kHasBias,
kHasDropout,
HstuMask,
HstuAttentionShape,
HstuAttentionTileSetting,
HstuTraits>;
static void Run(HstuAttentionFwdParams& param, hipStream_t stream)
{
constexpr ck_tile::index_t occupancy = -1;
const bool pad_seqlen_k = !(param.seqlen % HstuAttentionShape::kN0 == 0);
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionShape::kSubQKHeaddim == 0);
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionShape::kN1 == 0);
const bool pad_seqlen_k = !(param.seqlen % HstuAttentionTileSetting::kN0 == 0);
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kSubQKHeaddim == 0);
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0);
// no need to check seqlen_q since it is not used as fastest dim,
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access

View File

@@ -23,20 +23,18 @@ struct HstuAttentionFwdPipelineQRKSVS
using ODataType = remove_cvref_t<typename Problem::InOutDataType>;
using HstuMask = remove_cvref_t<typename Problem::HstuMask>;
using HstuAttentionTileShape = remove_cvref_t<typename Problem::HstuAttentionTileShape>;
using VLayout = remove_cvref_t<typename HstuAttentionTileShape::VLayout>;
static constexpr bool kQLoadOnce = true;
static_assert(kQLoadOnce == Policy::QLoadOnce);
using HstuAttentionTileSetting = remove_cvref_t<typename Problem::HstuAttentionTileSetting>;
using VLayout = remove_cvref_t<typename HstuAttentionTileSetting::VLayout>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = HstuAttentionTileShape::kM0;
static constexpr index_t kN0 = HstuAttentionTileShape::kN0;
static constexpr index_t kK0 = HstuAttentionTileShape::kK0;
static constexpr index_t kN1 = HstuAttentionTileShape::kN1;
static constexpr index_t kK1 = HstuAttentionTileShape::kK1;
static constexpr index_t kQKHeaddim = HstuAttentionTileShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = HstuAttentionTileShape::kSubQKHeaddim;
static constexpr index_t kM0 = HstuAttentionTileSetting::kM0;
static constexpr index_t kN0 = HstuAttentionTileSetting::kN0;
static constexpr index_t kK0 = HstuAttentionTileSetting::kK0;
static constexpr index_t kN1 = HstuAttentionTileSetting::kN1;
static constexpr index_t kK1 = HstuAttentionTileSetting::kK1;
static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = HstuAttentionTileSetting::kSubQKHeaddim;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
@@ -283,7 +281,7 @@ struct HstuAttentionFwdPipelineQRKSVS
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kK1>{}),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
Policy::template MakeBiasDramTileDistribution<Problem>());
auto null_randval_window = [&]() {
if constexpr(kHasDropout)

View File

@@ -4,7 +4,16 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
#include "block_gemm_areg_bsmem_creg_v2_hack_0.hpp"
#include "block_gemm_areg_bsmem_creg_v2_hack_1.hpp"
@@ -12,10 +21,6 @@
namespace ck_tile {
struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopy = */ false,
/* NumPrefetchK = */ -1,
/* NumPrefetchV = */ 1>
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetNumKVLdsBuffers()
@@ -29,8 +34,9 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr index_t kBlockGemmM = GetQKBlockGemmSingleRepM<Problem>();
return BlockGemm::
template MakeABlockTileDistribution<kBlockGemmM, Problem::BlockFmhaShape::kQKHeaddim>();
return BlockGemm::template MakeABlockTileDistribution<
kBlockGemmM,
Problem::HstuAttentionTileSetting::kQKHeaddim>();
}
template <typename Problem>
@@ -39,8 +45,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
return BlockGemm::template MakeABlockTileDistribution<
Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim>();
Problem::HstuAttentionTileSetting::kM0,
Problem::HstuAttentionTileSetting::kQKHeaddim>();
}
template <typename Problem>
@@ -63,6 +69,24 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
return WG::WarpGemmAttribute::kKPerThread;
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
return BlockGemm::MakeCBlockTile().get_tile_distribution();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{
@@ -75,11 +99,11 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using QDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
@@ -99,11 +123,11 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using KDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
@@ -123,12 +147,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
using VDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
@@ -155,8 +179,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kKVector = GetAlignmentK<Problem>();
@@ -177,10 +201,10 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize()
{
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -222,7 +246,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t kKVector = GetAlignmentQ<Problem>();
@@ -282,7 +306,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
@@ -311,8 +335,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kKVector = GetAlignmentK<Problem>();
@@ -394,8 +418,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
@@ -421,12 +445,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
{
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers<Problem>();
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -556,11 +580,11 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
{
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -614,12 +638,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegTileDistribution()
{
// This tile-distribuiton only used when V layout is RowMajor
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
@@ -644,27 +668,29 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetQKBlockGemmSingleRepM()
{
return Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}) *
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
return Problem::HstuAttentionTileSetting::Gemm0BlockWarps::at(number<0>{}) *
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{});
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using GemmProblem =
BlockGemmProblem<typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
Problem::kNumGemm0Warps * get_warp_size(),
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kK1,
Problem::BlockFmhaShape::kQKHeaddim>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
using GemmProblem = BlockGemmProblem<
typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
Problem::kNumGemm0Warps * get_warp_size(),
TileGemmShape<sequence<Problem::HstuAttentionTileSetting::kM0,
Problem::HstuAttentionTileSetting::kK1,
Problem::HstuAttentionTileSetting::kQKHeaddim>,
typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps,
typename Problem::HstuAttentionTileSetting::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
constexpr index_t WarpGemmK = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{});
constexpr index_t WarpGemmM =
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{});
constexpr index_t WarpGemmK =
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{});
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QKVDataType, half_t> &&
@@ -709,12 +735,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
} // TODO - bf8_t
}();
using BlockGemmPolicy =
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm)>;
using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps,
decltype(warp_gemm)>;
if constexpr(1 < Problem::kNumGemm0Warps)
return BlockGemmARegBSmemCRegV2Hack_0<GemmProblem, BlockGemmPolicy>{};
@@ -725,23 +751,23 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemmSingleRepN()
{
return Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) *
Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
return Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}) *
Problem::HstuAttentionTileSetting::Gemm1BlockWarps::at(number<1>{});
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{
using GemmProblem =
BlockGemmProblem<typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
Problem::kNumGemm1Warps * get_warp_size(),
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
using GemmProblem = BlockGemmProblem<
typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
Problem::kNumGemm1Warps * get_warp_size(),
TileGemmShape<sequence<Problem::HstuAttentionTileSetting::kM0,
Problem::HstuAttentionTileSetting::kN1,
Problem::HstuAttentionTileSetting::kK1>,
typename Problem::HstuAttentionTileSetting::Gemm1BlockWarps,
typename Problem::HstuAttentionTileSetting::Gemm1WarpTile>>;
auto warp_gemm = [&]() {
if constexpr(std::is_same_v<typename Problem::QKVDataType, fp8_t> &&
@@ -759,21 +785,21 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}),
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}),
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}),
true>{};
}
}();
using WarpGemm = remove_cvref_t<decltype(warp_gemm)>;
using BlockGemmPolicy =
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>;
using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
typename Problem::HstuAttentionTileSetting::Gemm1BlockWarps,
WarpGemm>;
return BlockGemmARegBSmemCRegV2Hack_1<GemmProblem, BlockGemmPolicy>{};
}
@@ -803,10 +829,16 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
sizeof(typename Problem::QKVDataType);
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout()
{
return 0;
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return max(GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(0),
return max(GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(),
GetSmemSizeQ<Problem>());
}
};

View File

@@ -7,9 +7,9 @@
#pragma once
#include <ck_tile/core.hpp>
#include <ck_tile/ops/fmha.hpp>
#include "hstu_attention_fwd_type_config.hpp"
#include "hstu_attention_tile_setting_define.hpp"
template <ck_tile::index_t MaxK>
struct HstuAttentionFwdBlockTile;
@@ -51,48 +51,52 @@ struct HstuAttentionFwdBlockTile<256>
using HstuAttentionFwdWarpTile1 = ck_tile::sequence<16, 16, 16>;
template <ck_tile::index_t MaxK>
struct HstuAttentionFwdShape;
struct HstuAttentionFwdTileSetting;
template <>
struct HstuAttentionFwdShape<32>
struct HstuAttentionFwdTileSetting<32>
{
using Type = ck_tile::TileFmhaShape<typename HstuAttentionFwdBlockTile<32>::type,
typename HstuAttentionFwdBlockTile<32>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<32>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionFwdBlockTile<32>::type,
typename HstuAttentionFwdBlockTile<32>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<32>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
};
template <>
struct HstuAttentionFwdShape<64>
struct HstuAttentionFwdTileSetting<64>
{
using Type = ck_tile::TileFmhaShape<typename HstuAttentionFwdBlockTile<64>::type,
typename HstuAttentionFwdBlockTile<64>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<64>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionFwdBlockTile<64>::type,
typename HstuAttentionFwdBlockTile<64>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<64>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
};
template <>
struct HstuAttentionFwdShape<128>
struct HstuAttentionFwdTileSetting<128>
{
using Type = ck_tile::TileFmhaShape<typename HstuAttentionFwdBlockTile<128>::type,
typename HstuAttentionFwdBlockTile<128>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<128>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionFwdBlockTile<128>::type,
typename HstuAttentionFwdBlockTile<128>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<128>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
};
template <>
struct HstuAttentionFwdShape<256>
struct HstuAttentionFwdTileSetting<256>
{
using Type = ck_tile::TileFmhaShape<typename HstuAttentionFwdBlockTile<256>::type,
typename HstuAttentionFwdBlockTile<256>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<256>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionFwdBlockTile<256>::type,
typename HstuAttentionFwdBlockTile<256>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<256>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
};

View File

@@ -10,7 +10,6 @@
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/host/stream_config.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include <ck_tile/ops/fmha.hpp>
#include "hstu_attention_bool_switch.hpp"
#include "hstu_attention_fwd_type_config.hpp"
@@ -32,8 +31,8 @@ template <typename InOutDataType,
ck_tile::index_t MaxK>
struct jagged_forward_causal_local_bias_dropout_dispatch
{
using HstuAttentionShape = typename HstuAttentionFwdShape<MaxK>::Type;
using HstuMask = typename ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>::Type;
using HstuAttentionTileSetting = typename HstuAttentionFwdTileSetting<MaxK>::Type;
using HstuMask = typename ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>::Type;
template <typename HstuTraits>
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
@@ -45,15 +44,15 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
kHasBias,
kHasDropout,
HstuMask,
HstuAttentionShape,
HstuAttentionTileSetting,
HstuTraits>;
static void Run(HstuAttentionFwdParams& param, hipStream_t stream)
{
constexpr ck_tile::index_t occupancy = -1;
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionShape::kSubQKHeaddim == 0);
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionShape::kN1 == 0);
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kSubQKHeaddim == 0);
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0);
// no need to check seqlen_q since it is not used as fastest dim,
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access

View File

@@ -21,7 +21,7 @@ template <typename InOutDataType_,
bool kHasBias_,
bool kHasDropout_,
typename HstuMask_, // encoding Causal and Local, contextual masking
typename AttentionTileShape_,
typename AttentionTileSetting_,
typename Traits_>
struct HstuAttentionFwdPipelineProblem
{
@@ -35,10 +35,6 @@ struct HstuAttentionFwdPipelineProblem
using BiasDataType = remove_cvref_t<BiasDataType_>;
// to be compatible with ck_tile existing policy codes
using QDataType = QKVDataType;
using KDataType = QKVDataType;
using VDataType = QKVDataType;
using SaccDataType = GemmAccDataType;
using OaccDataType = GemmAccDataType;
using PDataType = QKVDataType;
@@ -48,15 +44,13 @@ struct HstuAttentionFwdPipelineProblem
using HstuMask = remove_cvref_t<HstuMask_>;
using HstuAttentionTileShape = remove_cvref_t<AttentionTileShape_>;
using HstuAttentionTileSetting = remove_cvref_t<AttentionTileSetting_>;
// Keep the name compatible with ck_tile existing policy codes, to be changed
using BlockFmhaShape = HstuAttentionTileShape;
using Traits = remove_cvref_t<Traits_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kNumGemm0Warps = AttentionTileShape_::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = AttentionTileShape_::NumGemm1Warps;
static constexpr index_t kBlockSize = AttentionTileShape_::NumWarps * get_warp_size();
static constexpr index_t kNumGemm0Warps = AttentionTileSetting_::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = AttentionTileSetting_::NumGemm1Warps;
static constexpr index_t kBlockSize = AttentionTileSetting_::NumWarps * get_warp_size();
};
} // namespace ck_tile

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
{
if(len == 96)
return 128;
if(len == 160)
return 256;
if(len == 192)
return 192;
// only length of 96, 160, 192 and power-of-two is supported
if(!(len & (len - 1)))
return len;
return 0;
};
template <typename BlockTile_, // sequence<...
typename Gemm0BlockWarps_,
typename Gemm0WarpTile_,
typename Gemm1BlockWarps_,
typename Gemm1WarpTile_,
bool IsVLayoutRowMajor_>
struct HstuAttentionFwdTileSettingClass
{
using BlockTile = remove_cvref_t<BlockTile_>;
using Gemm0BlockWarps = remove_cvref_t<Gemm0BlockWarps_>;
using Gemm0WarpTile = remove_cvref_t<Gemm0WarpTile_>;
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
static constexpr index_t NumGemm0Warps =
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t NumGemm1Warps =
reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{});
static_assert(NumGemm1Warps % NumGemm0Warps == 0);
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
static constexpr index_t kQKHeaddim =
BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
// once (or repeately load Q as a whole tile)
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
using VLayout = std::conditional_t<IsVLayoutRowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
};
} // namespace ck_tile