mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Completely remove the dependency to include/ck_tile/ops/fmha/ops headers
This commit is contained in:
@@ -10,7 +10,6 @@
|
||||
#include <ck_tile/host/kernel_launch.hpp>
|
||||
#include <ck_tile/host/stream_config.hpp>
|
||||
#include <ck_tile/ops/epilogue.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
|
||||
#include "hstu_attention_bool_switch.hpp"
|
||||
#include "hstu_attention_fwd_type_config.hpp"
|
||||
@@ -32,8 +31,8 @@ template <typename InOutDataType,
|
||||
ck_tile::index_t MaxK>
|
||||
struct batched_forward_causal_local_bias_dropout_dispatch
|
||||
{
|
||||
using HstuAttentionShape = typename HstuAttentionFwdShape<MaxK>::Type;
|
||||
using HstuMask = typename ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>::Type;
|
||||
using HstuAttentionTileSetting = typename HstuAttentionFwdTileSetting<MaxK>::Type;
|
||||
using HstuMask = typename ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>::Type;
|
||||
|
||||
template <typename HstuTraits>
|
||||
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
|
||||
@@ -45,16 +44,16 @@ struct batched_forward_causal_local_bias_dropout_dispatch
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
HstuMask,
|
||||
HstuAttentionShape,
|
||||
HstuAttentionTileSetting,
|
||||
HstuTraits>;
|
||||
|
||||
static void Run(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
constexpr ck_tile::index_t occupancy = -1;
|
||||
|
||||
const bool pad_seqlen_k = !(param.seqlen % HstuAttentionShape::kN0 == 0);
|
||||
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionShape::kSubQKHeaddim == 0);
|
||||
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionShape::kN1 == 0);
|
||||
const bool pad_seqlen_k = !(param.seqlen % HstuAttentionTileSetting::kN0 == 0);
|
||||
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kSubQKHeaddim == 0);
|
||||
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0);
|
||||
|
||||
// no need to check seqlen_q since it is not used as fastest dim,
|
||||
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
|
||||
|
||||
@@ -23,20 +23,18 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
using ODataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
using HstuMask = remove_cvref_t<typename Problem::HstuMask>;
|
||||
|
||||
using HstuAttentionTileShape = remove_cvref_t<typename Problem::HstuAttentionTileShape>;
|
||||
using VLayout = remove_cvref_t<typename HstuAttentionTileShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true;
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
using HstuAttentionTileSetting = remove_cvref_t<typename Problem::HstuAttentionTileSetting>;
|
||||
using VLayout = remove_cvref_t<typename HstuAttentionTileSetting::VLayout>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = HstuAttentionTileShape::kM0;
|
||||
static constexpr index_t kN0 = HstuAttentionTileShape::kN0;
|
||||
static constexpr index_t kK0 = HstuAttentionTileShape::kK0;
|
||||
static constexpr index_t kN1 = HstuAttentionTileShape::kN1;
|
||||
static constexpr index_t kK1 = HstuAttentionTileShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = HstuAttentionTileShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = HstuAttentionTileShape::kSubQKHeaddim;
|
||||
static constexpr index_t kM0 = HstuAttentionTileSetting::kM0;
|
||||
static constexpr index_t kN0 = HstuAttentionTileSetting::kN0;
|
||||
static constexpr index_t kK0 = HstuAttentionTileSetting::kK0;
|
||||
static constexpr index_t kN1 = HstuAttentionTileSetting::kN1;
|
||||
static constexpr index_t kK1 = HstuAttentionTileSetting::kK1;
|
||||
static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
@@ -283,7 +281,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kM0>{}, number<kK1>{}),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
Policy::template MakeBiasDramTileDistribution<Problem>());
|
||||
|
||||
auto null_randval_window = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
|
||||
@@ -4,7 +4,16 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
|
||||
|
||||
#include "block_gemm_areg_bsmem_creg_v2_hack_0.hpp"
|
||||
#include "block_gemm_areg_bsmem_creg_v2_hack_1.hpp"
|
||||
@@ -12,10 +21,6 @@
|
||||
namespace ck_tile {
|
||||
|
||||
struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
|
||||
/* AsyncCopy = */ false,
|
||||
/* NumPrefetchK = */ -1,
|
||||
/* NumPrefetchV = */ 1>
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto GetNumKVLdsBuffers()
|
||||
@@ -29,8 +34,9 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr index_t kBlockGemmM = GetQKBlockGemmSingleRepM<Problem>();
|
||||
|
||||
return BlockGemm::
|
||||
template MakeABlockTileDistribution<kBlockGemmM, Problem::BlockFmhaShape::kQKHeaddim>();
|
||||
return BlockGemm::template MakeABlockTileDistribution<
|
||||
kBlockGemmM,
|
||||
Problem::HstuAttentionTileSetting::kQKHeaddim>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -39,8 +45,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
|
||||
return BlockGemm::template MakeABlockTileDistribution<
|
||||
Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim>();
|
||||
Problem::HstuAttentionTileSetting::kM0,
|
||||
Problem::HstuAttentionTileSetting::kQKHeaddim>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -63,6 +69,24 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
return WG::WarpGemmAttribute::kKPerThread;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
|
||||
return BlockGemm::MakeCBlockTile().get_tile_distribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
|
||||
{
|
||||
@@ -75,11 +99,11 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using QDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
|
||||
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
|
||||
@@ -99,11 +123,11 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
|
||||
{
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
@@ -123,12 +147,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
|
||||
using VDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
@@ -155,8 +179,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
@@ -177,10 +201,10 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize()
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
|
||||
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
@@ -222,7 +246,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentQ<Problem>();
|
||||
|
||||
@@ -282,7 +306,7 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
|
||||
|
||||
@@ -311,8 +335,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers<Problem>();
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
@@ -394,8 +418,8 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
@@ -421,12 +445,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
|
||||
|
||||
constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers<Problem>();
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
|
||||
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
@@ -556,11 +580,11 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
|
||||
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
@@ -614,12 +638,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegTileDistribution()
|
||||
{
|
||||
// This tile-distribuiton only used when V layout is RowMajor
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
|
||||
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
@@ -644,27 +668,29 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetQKBlockGemmSingleRepM()
|
||||
{
|
||||
return Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}) *
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
return Problem::HstuAttentionTileSetting::Gemm0BlockWarps::at(number<0>{}) *
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{});
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::kNumGemm0Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kK1,
|
||||
Problem::BlockFmhaShape::kQKHeaddim>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
using GemmProblem = BlockGemmProblem<
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::kNumGemm0Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::HstuAttentionTileSetting::kM0,
|
||||
Problem::HstuAttentionTileSetting::kK1,
|
||||
Problem::HstuAttentionTileSetting::kQKHeaddim>,
|
||||
typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps,
|
||||
typename Problem::HstuAttentionTileSetting::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
constexpr index_t WarpGemmK = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{});
|
||||
constexpr index_t WarpGemmM =
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{});
|
||||
constexpr index_t WarpGemmK =
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{});
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
|
||||
if constexpr(std::is_same_v<typename Problem::QKVDataType, half_t> &&
|
||||
@@ -709,12 +735,12 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
} // TODO - bf8_t
|
||||
}();
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
decltype(warp_gemm)>;
|
||||
using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps,
|
||||
decltype(warp_gemm)>;
|
||||
|
||||
if constexpr(1 < Problem::kNumGemm0Warps)
|
||||
return BlockGemmARegBSmemCRegV2Hack_0<GemmProblem, BlockGemmPolicy>{};
|
||||
@@ -725,23 +751,23 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemmSingleRepN()
|
||||
{
|
||||
return Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) *
|
||||
Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
|
||||
return Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}) *
|
||||
Problem::HstuAttentionTileSetting::Gemm1BlockWarps::at(number<1>{});
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
|
||||
{
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::kNumGemm1Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN1,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
using GemmProblem = BlockGemmProblem<
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::kNumGemm1Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::HstuAttentionTileSetting::kM0,
|
||||
Problem::HstuAttentionTileSetting::kN1,
|
||||
Problem::HstuAttentionTileSetting::kK1>,
|
||||
typename Problem::HstuAttentionTileSetting::Gemm1BlockWarps,
|
||||
typename Problem::HstuAttentionTileSetting::Gemm1WarpTile>>;
|
||||
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr(std::is_same_v<typename Problem::QKVDataType, fp8_t> &&
|
||||
@@ -759,21 +785,21 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{}),
|
||||
true>{};
|
||||
}
|
||||
}();
|
||||
|
||||
using WarpGemm = remove_cvref_t<decltype(warp_gemm)>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
WarpGemm>;
|
||||
using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
typename Problem::HstuAttentionTileSetting::Gemm1BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockGemmARegBSmemCRegV2Hack_1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
@@ -803,10 +829,16 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
sizeof(typename Problem::QKVDataType);
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout()
|
||||
{
|
||||
return 0;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return max(GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(0),
|
||||
return max(GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(),
|
||||
GetSmemSizeQ<Problem>());
|
||||
}
|
||||
};
|
||||
|
||||
@@ -7,9 +7,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
|
||||
#include "hstu_attention_fwd_type_config.hpp"
|
||||
#include "hstu_attention_tile_setting_define.hpp"
|
||||
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionFwdBlockTile;
|
||||
@@ -51,48 +51,52 @@ struct HstuAttentionFwdBlockTile<256>
|
||||
using HstuAttentionFwdWarpTile1 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionFwdShape;
|
||||
struct HstuAttentionFwdTileSetting;
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdShape<32>
|
||||
struct HstuAttentionFwdTileSetting<32>
|
||||
{
|
||||
using Type = ck_tile::TileFmhaShape<typename HstuAttentionFwdBlockTile<32>::type,
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<32>::type,
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdShape<64>
|
||||
struct HstuAttentionFwdTileSetting<64>
|
||||
{
|
||||
using Type = ck_tile::TileFmhaShape<typename HstuAttentionFwdBlockTile<64>::type,
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<64>::type,
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdShape<128>
|
||||
struct HstuAttentionFwdTileSetting<128>
|
||||
{
|
||||
using Type = ck_tile::TileFmhaShape<typename HstuAttentionFwdBlockTile<128>::type,
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<128>::type,
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdShape<256>
|
||||
struct HstuAttentionFwdTileSetting<256>
|
||||
{
|
||||
using Type = ck_tile::TileFmhaShape<typename HstuAttentionFwdBlockTile<256>::type,
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<256>::type,
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
};
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
#include <ck_tile/host/kernel_launch.hpp>
|
||||
#include <ck_tile/host/stream_config.hpp>
|
||||
#include <ck_tile/ops/epilogue.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
|
||||
#include "hstu_attention_bool_switch.hpp"
|
||||
#include "hstu_attention_fwd_type_config.hpp"
|
||||
@@ -32,8 +31,8 @@ template <typename InOutDataType,
|
||||
ck_tile::index_t MaxK>
|
||||
struct jagged_forward_causal_local_bias_dropout_dispatch
|
||||
{
|
||||
using HstuAttentionShape = typename HstuAttentionFwdShape<MaxK>::Type;
|
||||
using HstuMask = typename ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>::Type;
|
||||
using HstuAttentionTileSetting = typename HstuAttentionFwdTileSetting<MaxK>::Type;
|
||||
using HstuMask = typename ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>::Type;
|
||||
|
||||
template <typename HstuTraits>
|
||||
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
|
||||
@@ -45,15 +44,15 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
HstuMask,
|
||||
HstuAttentionShape,
|
||||
HstuAttentionTileSetting,
|
||||
HstuTraits>;
|
||||
|
||||
static void Run(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
constexpr ck_tile::index_t occupancy = -1;
|
||||
|
||||
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionShape::kSubQKHeaddim == 0);
|
||||
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionShape::kN1 == 0);
|
||||
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionTileSetting::kSubQKHeaddim == 0);
|
||||
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionTileSetting::kN1 == 0);
|
||||
|
||||
// no need to check seqlen_q since it is not used as fastest dim,
|
||||
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
|
||||
|
||||
@@ -21,7 +21,7 @@ template <typename InOutDataType_,
|
||||
bool kHasBias_,
|
||||
bool kHasDropout_,
|
||||
typename HstuMask_, // encoding Causal and Local, contextual masking
|
||||
typename AttentionTileShape_,
|
||||
typename AttentionTileSetting_,
|
||||
typename Traits_>
|
||||
struct HstuAttentionFwdPipelineProblem
|
||||
{
|
||||
@@ -35,10 +35,6 @@ struct HstuAttentionFwdPipelineProblem
|
||||
using BiasDataType = remove_cvref_t<BiasDataType_>;
|
||||
|
||||
// to be compatible with ck_tile existing policy codes
|
||||
using QDataType = QKVDataType;
|
||||
using KDataType = QKVDataType;
|
||||
using VDataType = QKVDataType;
|
||||
using SaccDataType = GemmAccDataType;
|
||||
using OaccDataType = GemmAccDataType;
|
||||
using PDataType = QKVDataType;
|
||||
|
||||
@@ -48,15 +44,13 @@ struct HstuAttentionFwdPipelineProblem
|
||||
|
||||
using HstuMask = remove_cvref_t<HstuMask_>;
|
||||
|
||||
using HstuAttentionTileShape = remove_cvref_t<AttentionTileShape_>;
|
||||
using HstuAttentionTileSetting = remove_cvref_t<AttentionTileSetting_>;
|
||||
|
||||
// Keep the name compatible with ck_tile existing policy codes, to be changed
|
||||
using BlockFmhaShape = HstuAttentionTileShape;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kNumGemm0Warps = AttentionTileShape_::NumGemm0Warps;
|
||||
static constexpr index_t kNumGemm1Warps = AttentionTileShape_::NumGemm1Warps;
|
||||
static constexpr index_t kBlockSize = AttentionTileShape_::NumWarps * get_warp_size();
|
||||
static constexpr index_t kNumGemm0Warps = AttentionTileSetting_::NumGemm0Warps;
|
||||
static constexpr index_t kNumGemm1Warps = AttentionTileSetting_::NumGemm1Warps;
|
||||
static constexpr index_t kBlockSize = AttentionTileSetting_::NumWarps * get_warp_size();
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
|
||||
{
|
||||
if(len == 96)
|
||||
return 128;
|
||||
if(len == 160)
|
||||
return 256;
|
||||
if(len == 192)
|
||||
return 192;
|
||||
|
||||
// only length of 96, 160, 192 and power-of-two is supported
|
||||
if(!(len & (len - 1)))
|
||||
return len;
|
||||
|
||||
return 0;
|
||||
};
|
||||
|
||||
template <typename BlockTile_, // sequence<...
|
||||
typename Gemm0BlockWarps_,
|
||||
typename Gemm0WarpTile_,
|
||||
typename Gemm1BlockWarps_,
|
||||
typename Gemm1WarpTile_,
|
||||
bool IsVLayoutRowMajor_>
|
||||
struct HstuAttentionFwdTileSettingClass
|
||||
{
|
||||
using BlockTile = remove_cvref_t<BlockTile_>;
|
||||
using Gemm0BlockWarps = remove_cvref_t<Gemm0BlockWarps_>;
|
||||
using Gemm0WarpTile = remove_cvref_t<Gemm0WarpTile_>;
|
||||
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
|
||||
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
|
||||
|
||||
static constexpr index_t NumGemm0Warps =
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
|
||||
static constexpr index_t NumGemm1Warps =
|
||||
reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{});
|
||||
static_assert(NumGemm1Warps % NumGemm0Warps == 0);
|
||||
|
||||
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
|
||||
|
||||
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
|
||||
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
|
||||
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
|
||||
static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
|
||||
static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
|
||||
static constexpr index_t kQKHeaddim =
|
||||
BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
|
||||
// once (or repeately load Q as a whole tile)
|
||||
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
|
||||
|
||||
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
|
||||
|
||||
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
|
||||
static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
|
||||
using VLayout = std::conditional_t<IsVLayoutRowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user