Implementation of HSTU attention pipeline using trload for V on MI350

This commit is contained in:
Qianfeng Zhang
2025-10-27 14:54:36 +00:00
parent a464269bb6
commit 207e6f10b8
8 changed files with 1153 additions and 72 deletions

View File

@@ -0,0 +1,247 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
// Block-level GEMM (C += A * B) where:
//   - A is a block-distributed tensor already resident in registers,
//   - B is a block window over shared memory, read per-warp via
//     load_tile_transpose (transpose-on-load path),
//   - C is a block-distributed accumulator in registers.
// Variant of BlockGemmARegBSmemCRegV2Hack_1 that swaps the B side to
// transpose loads; used by the HSTU attention trload pipeline.
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
struct BlockGemmARegBSmemTrLoadCRegV2Hack_1
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// C += A * B
// @param c_block_tensor     [MPerBlock, NPerBlock] register accumulator, updated in place.
// @param a_block_tensor_tmp [MPerBlock, KPerBlock] register tile; re-wrapped below into
//                           this GEMM's own A distribution.
// @param b_block_window_tmp shared-memory window; its window lengths are indexed as
//                           [dim0, dim1] = [K, N] here (N taken from index 1).
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
// Element types of the supplied tiles/windows must agree with the Problem's types.
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
// Per-block extents are recovered from the (compile-time) argument shapes and
// checked against the configured BlockGemmShape.
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<1>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK,
"wrong!");
// Warp-gemm type (WG) and the MWarp x NWarp warp layout come from the policy.
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
// How many warp-gemm tiles each warp iterates over in M/N/K to cover the block.
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
// Step sizes (in elements) for moving the B warp window between iterations.
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
// Which N-warp lane this wave occupies; offsets the B warp window in N.
const index_t iNWarp = get_warp_id() % NWarp;
// Outer (block-level) C distribution: iterations x warps in M and N.
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// construct A-block-tensor from A-Block-tensor-tmp: re-wrap the incoming thread
// buffer under this GEMM's own A distribution (raw copy of the thread buffer).
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
MakeABlockTileDistribution());
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// B warp distribution for the transpose-load path: the WG's B encoding is mapped
// through InputTileDistributionTraits to its transposed form.
constexpr auto b_warp_dstr_encode =
typename InputTileDistributionTraits<typename WG::BWarpDstrEncoding,
BDataType>::TransposedDstrEncode{};
// construct B-warp-window: a [WG::kK, WG::kN] view into the B block window,
// shifted in N by this wave's iNWarp slot.
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kK>{}, number<WG::kN>{}),
b_block_window_tmp.get_window_origin() + multi_index<2>{0, iNWarp * WG::kN},
make_static_tile_distribution(b_warp_dstr_encode));
// One window per (nIter, kIter) position; filled lazily below.
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
// check C-block-distribution: the caller's C tile must use exactly the
// distribution this GEMM produces.
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
// Y-dimension extents/zero-offsets used for slicing warp tiles out of block tiles.
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto I0 = number<0>{};
// Register storage for every (nIter, kIter) B warp tile loaded via transpose load.
using b_warp_tensor_type = decltype(load_tile_transpose(b_warp_windows(I0)(I0)));
statically_indexed_array<statically_indexed_array<b_warp_tensor_type, KIterPerWarp>,
NIterPerWarp>
b_warp_tensors;
// Preload all K-iterations of the first N-iteration (nIter == 0).
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(I0)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(I0)(kIter),
{kIter * KPerBlockPerIter, 0 * NPerBlockPerIter});
b_warp_tensors(I0)(kIter) = load_tile_transpose(b_warp_windows(I0)(kIter));
});
// Keep the compiler's instruction scheduler from moving work across this point,
// so the preload above stays grouped ahead of the main loop.
__builtin_amdgcn_sched_barrier(0);
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// Prefetch the next N-iteration's B tiles while the current one computes.
if constexpr(nIter < NIterPerWarp - 1)
{
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(number<nIter + 1>{})(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(number<nIter + 1>{})(kIter),
{kIter * KPerBlockPerIter, (nIter + 1) * NPerBlockPerIter});
b_warp_tensors(number<nIter + 1>{})(kIter) =
load_tile_transpose(b_warp_windows(number<nIter + 1>{})(kIter));
});
};
// Again fence the scheduler between the prefetch and the compute of this nIter.
__builtin_amdgcn_sched_barrier(0);
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM: c_warp += a_warp * b_warp (B comes from the prefetched tiles)
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter][kIter]);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
// Builds the tile distribution the A operand must use for this GEMM.
// Defaults cover the full block; callers may pass smaller M/K extents
// (e.g. for partial tiles). NWarp appears as a replication dimension since
// every N-warp needs the same A data.
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
{
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
return make_static_tile_distribution(a_block_dstr_encode);
}
// Creates a zero-initialized-by-construction C accumulator tile with the
// block-level C distribution this GEMM writes through (same encoding as the
// static_assert inside operator() checks).
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C = A * B
// Convenience overload: allocates a fresh C tile, accumulates into it via the
// C += A * B overload above, and returns it.
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
return c_block_tensor;
}
};
} // namespace ck_tile

View File

@@ -16,6 +16,7 @@
#include "hstu_attention_pipeline_problem.hpp"
#include "hstu_attention_traits.hpp"
#include "hstu_attention_fwd_pipeline.hpp"
#include "hstu_attention_fwd_trload_pipeline.hpp"
#include "hstu_attention_fwd_kernel.hpp"
#include "hstu_attention_epilogue.hpp"
@@ -32,6 +33,12 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
HstuAttentionWithSoftmaxFwdTileSetting<MaxK>,
HstuAttentionNoSoftmaxFwdTileSetting<MaxK>>::Type;
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
static constexpr bool kUseTrLoad = true;
#else
static constexpr bool kUseTrLoad = false;
#endif
template <typename HstuTraits>
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
InOutDataType,
@@ -43,6 +50,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
kHasDropout,
kUseCausal,
kUseSoftmax,
kUseTrLoad,
HstuAttentionTileSetting,
HstuTraits>;
@@ -80,8 +88,11 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
kPadSeqLenQ,
kPadHeadDimV>>;
using HstuPipeline =
ck_tile::HstuAttentionFwdPipelineQRKSVS<HstuPipelineProblem>;
using HstuPipeline = std::conditional_t<
kUseTrLoad,
ck_tile::HstuAttentionFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>,
ck_tile::HstuAttentionFwdPipelineQRKSVS<HstuPipelineProblem>>;
using HstuKernel =
ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;

View File

@@ -48,6 +48,8 @@ struct HstuAttentionFwdKernel
static constexpr bool kHasDropout = HstuAttentionPipeline::kHasDropout;
static constexpr bool kHasCausalMask = HstuAttentionPipeline::kHasCausal;
static constexpr bool kUseTrLoad = HstuAttentionPipeline::kUseTrLoad;
template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce an template
// arg
struct HstuAttentionFwdEmptyKargs
@@ -583,17 +585,27 @@ struct HstuAttentionFwdKernel
number<HstuAttentionPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_transposed =
transform_tensor_view(v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
if constexpr(!kUseTrLoad)
{
const auto v_dram_transposed =
transform_tensor_view(v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return pad_tensor_view(v_dram_transposed,
make_tuple(number<HstuAttentionPipeline::kN1>{},
number<HstuAttentionPipeline::kK1>{}),
sequence<kPadHeadDimV, false>{});
return pad_tensor_view(v_dram_transposed,
make_tuple(number<HstuAttentionPipeline::kN1>{},
number<HstuAttentionPipeline::kK1>{}),
sequence<kPadHeadDimV, false>{});
}
else
{
return pad_tensor_view(v_dram_naive,
make_tuple(number<HstuAttentionPipeline::kK1>{},
number<HstuAttentionPipeline::kN1>{}),
sequence<false, kPadHeadDimV>{});
};
}();
auto q_dram_window =

View File

@@ -40,6 +40,10 @@ struct HstuAttentionFwdPipelineQRKSVS
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasCausal = Problem::kHasCausal;
static_assert(Problem::kUseTrLoad == false, "Check failed!");
static constexpr bool kUseTrLoad = false;
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;

View File

@@ -16,6 +16,7 @@
#include "block_gemm_areg_bsmem_creg_v2_hack_0.hpp"
#include "block_gemm_areg_bsmem_creg_v2_hack_1.hpp"
#include "block_gemm_areg_bsmem_trload_creg_v2_hack_1.hpp"
namespace ck_tile {
@@ -71,11 +72,26 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
if constexpr(!Problem::kUseTrLoad)
{
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
return BlockGemm::template MakeABlockTileDistribution<
Problem::HstuAttentionTileSetting::kM0,
Problem::HstuAttentionTileSetting::kN0>();
return BlockGemm::template MakeABlockTileDistribution<
Problem::HstuAttentionTileSetting::kM0,
Problem::HstuAttentionTileSetting::kN0>();
}
else
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto bias_block_dstr_encode =
BlockGemm::template MakeCBlockDistributionEncode<
Problem::HstuAttentionTileSetting::kM0,
Problem::HstuAttentionTileSetting::kN0>();
constexpr auto bias_block_dstr = make_static_tile_distribution(bias_block_dstr_encode);
return bias_block_dstr;
};
}
template <typename Problem>
@@ -148,23 +164,34 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
using VDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
if constexpr(!Problem::kUseTrLoad)
{
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (ElemPerThread / kMinVecLoad);
constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (ElemPerThread / kMinVecLoad);
return kVecLoad;
return kVecLoad;
}
else
{
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
return min(MaxVectorSize, ElemPerThread);
};
}
template <typename Problem>
@@ -195,11 +222,18 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize<Problem>();
if constexpr(!Problem::kUseTrLoad)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize<Problem>();
return N0 * (N1 * kKPerBlock + kKPack);
return N0 * (N1 * kKPerBlock + kKPack);
}
else
{
return kNPerBlock * kKPerBlock;
};
};
template <typename Problem>
@@ -470,43 +504,88 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
if constexpr(!Problem::kUseTrLoad)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
// K2 is the vector size for storing shuffled tile to LDS
constexpr index_t K2 = ElemPerThread / N1;
// K2 is the vector size for storing shuffled tile to LDS
constexpr index_t K2 = ElemPerThread / N1;
// GetSmemKPackV() is the vector size for loading from LDS by BlockGemm
constexpr index_t kKPack = GetSmemKPackV<Problem>();
// GetSmemKPackV() is the vector size for loading from LDS by BlockGemm
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(kKPack >= K2, "Check failed!");
static_assert(kKPack >= K2, "Check failed!");
constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack);
constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack);
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{}, number<kKPerBlock>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<N1 * kKPerBlock + kKPack>{},
number<kKPerBlock>{},
number<1>{}),
number<8>{},
number<1>{});
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(
number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{}, number<kKPerBlock>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<N1 * kKPerBlock + kKPack>{},
number<kKPerBlock>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{})),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{})),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc;
return v_lds_block_desc;
}
else
{
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr auto XorGroupSize =
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{});
constexpr index_t VSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock;
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
constexpr auto v_lds_block_desc_naive =
make_naive_tensor_descriptor(make_tuple(number<NumVLdsBuffers>{},
number<kKPerBlock>{},
number<kNPerBlock / XorGroupSize>{},
number<XorGroupSize>{}),
make_tuple(number<VSingleSmemElementSpaceSize>{},
number<kNPerBlock>{},
number<XorGroupSize>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor(
v_lds_block_desc_naive,
make_tuple(make_pass_through_transform(number<NumVLdsBuffers>{}),
make_xor_transform(make_tuple(number<kKPerBlock>{},
number<kNPerBlock / XorGroupSize>{})),
make_pass_through_transform(number<XorGroupSize>{})),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
return transform_tensor_descriptor(
v_lds_block_desc_permuted,
make_tuple(make_merge_transform(
make_tuple(number<NumVLdsBuffers>{}, number<kKPerBlock>{})),
make_merge_transform_v3_division_mod(make_tuple(
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
};
}
template <typename Problem>
@@ -516,26 +595,51 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
if constexpr(!Problem::kUseTrLoad)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(ElemPerThread % N1 == 0);
static_assert(ElemPerThread % N1 == 0);
constexpr index_t K2 = ElemPerThread / N1;
constexpr index_t K1 = get_warp_size() / N0;
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t K2 = ElemPerThread / N1;
constexpr index_t K1 = get_warp_size() / N0;
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<2, 1>>{});
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<2, 1>>{});
}
else
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(ElemPerThread % N1 == 0);
constexpr index_t K2 = ElemPerThread / N1;
constexpr index_t K1 = get_warp_size() / N0;
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<K0, K1, K2>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
};
}
// used when kUseTrLoad is false
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegTileDistribution()
{
@@ -717,7 +821,15 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
typename Problem::GemmAccDataType,
typename Problem::HstuAttentionTileSetting::Gemm1BlockWarps,
WarpGemm>;
return BlockGemmARegBSmemCRegV2Hack_1<GemmProblem, BlockGemmPolicy>{};
if constexpr(!Problem::kUseTrLoad)
{
return BlockGemmARegBSmemCRegV2Hack_1<GemmProblem, BlockGemmPolicy>{};
}
else
{
return BlockGemmARegBSmemTrLoadCRegV2Hack_1<GemmProblem, BlockGemmPolicy>{};
};
}
template <typename Problem>

View File

@@ -0,0 +1,682 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "hstu_attention_fwd_pipeline_default_policy.hpp"
namespace ck_tile {
template <typename Problem_, typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
struct HstuAttentionFwdPipelineQRKSVSTrLoad
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QKVDataType = remove_cvref_t<typename Problem::InOutDataType>;
using GemmAccDataType = remove_cvref_t<typename Problem::GemmAccDataType>;
using CompDataType = remove_cvref_t<typename Problem::CompDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using PDataType = remove_cvref_t<typename Problem::InOutDataType>;
using ODataType = remove_cvref_t<typename Problem::InOutDataType>;
using HstuAttentionTileSetting = remove_cvref_t<typename Problem::HstuAttentionTileSetting>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = HstuAttentionTileSetting::kM0;
static constexpr index_t kN0 = HstuAttentionTileSetting::kN0;
static constexpr index_t kN1 = HstuAttentionTileSetting::kN1;
static constexpr index_t kK1 = HstuAttentionTileSetting::kK1;
static constexpr index_t kQKHeaddim = HstuAttentionTileSetting::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = HstuAttentionTileSetting::kSubQKHeaddim;
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsJagged = Problem::kIsJagged;
static constexpr auto kHasBias = Problem::kHasBias;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasCausal = Problem::kHasCausal;
static_assert(Problem::kUseTrLoad == true, "Check failed!");
static constexpr bool kUseTrLoad = true;
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
static constexpr bool kPadHeadDimV =
(kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::Traits::kPadHeadDimV;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQK ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kGemmSingleRepM = Policy::template GetQKBlockGemmSingleRepM<Problem>();
static constexpr index_t kGemmNumRepM = kM0 / kGemmSingleRepM;
// used by NRepetitions2DEpilogue
static constexpr index_t kGemm1SingleRepN =
Policy::template GetKVBlockGemmSingleRepN<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::Traits::kBlockPerCu != -1)
return Problem::Traits::kBlockPerCu;
else
{
if constexpr(kQKHeaddim == 32)
{
return 2;
}
else if constexpr(kQKHeaddim == 64)
{
return 2;
}
else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128)
{
if constexpr(kHasBias)
return 2;
else
return 2;
}
else if constexpr(kQKHeaddim == 256)
{
return 1;
}
else
{
return 1;
};
}
}();
static constexpr const char* name = "qr_hstu";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename HstuMask>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
HstuMask& mask,
float scale_s, // scaling value exerted on the immediate Q@K result
float scale_p, // scaling value exerted on the SiLu result
void* smem_ptr,
DropoutType& dropout) const
{
ignore = q_element_func;
ignore = k_element_func;
static_assert(
std::is_same_v<QKVDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<QKVDataType,
remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<QKVDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr bool kUseSoftmax = Problem::kUseSoftmax;
constexpr index_t k1_loops = kN0 / kK1;
constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers<Problem>();
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
// SaccBlockTile size is [kM0, kK1]
// PcompBlockTile size is [kM0, kN0]
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kK1>());
using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0>());
using PcompBlockTileType = decltype(cast_tile<CompDataType>(CombineSaccBlockTileType{}));
SaccBlockTileType sacc_tile;
PcompBlockTileType pcomp_tile;
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
using MLBlockTileType = decltype(block_tile_reduce<CompDataType>(
PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0}));
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
OaccBlockTileType o_acc;
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kK1>{}, number<kQKHeaddim>{}),
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
using q_dram_tile_type = decltype(load_tile(q_dram_window));
statically_indexed_array<q_dram_tile_type, kGemmNumRepM> q_dram_tiles;
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
q_dram_tiles[i_rep] = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kGemmSingleRepM, 0});
});
using k_tile_type = decltype(load_tile(k_dram_window));
statically_indexed_array<k_tile_type, k1_loops> k_tiles;
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
k_tiles[i_k1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
});
__builtin_amdgcn_sched_barrier(0);
// Q tile in LDS
QKVDataType* q_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_write_window = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
auto q_lds_read_window =
make_tile_window(q_lds,
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
{0, 0},
Policy::template MakeQRegSingleRepMTileDistribution<Problem>());
// K tile in LDS
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_write_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
auto k_lds_read_window =
make_tile_window(k_lds, make_tuple(number<kK1>{}, number<kQKHeaddim>{}), {0, 0});
using k_lds_write_window_type = decltype(get_slice_tile(
k_lds_write_window, sequence<0, 0>{}, sequence<kK1, kSubQKHeaddim>{}));
using k_lds_read_window_type = decltype(get_slice_tile(
k_lds_read_window, sequence<0, 0>{}, sequence<kK1, kQKHeaddim>{}));
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
k_lds_write_windows[i_buf] =
get_slice_tile(k_lds_write_window,
sequence<i_buf * kK1, 0>{},
sequence<(i_buf + 1) * kK1, kSubQKHeaddim>{});
k_lds_read_windows[i_buf] = get_slice_tile(k_lds_read_window,
sequence<i_buf * kK1, 0>{},
sequence<(i_buf + 1) * kK1, kQKHeaddim>{});
});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<QKVDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
using v_lds_window_type =
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kK1, kN1>{}));
statically_indexed_array<v_lds_window_type, NumKVLdsBuffers> v_lds_windows;
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
v_lds_windows[i_buf] = get_slice_tile(
v_lds_window, sequence<i_buf * kK1, 0>{}, sequence<(i_buf + 1) * kK1, kN1>{});
});
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0},
Policy::template MakeVDramTileDistribution<Problem>());
// reduction function for softmax
const auto f_silu = [&](CompDataType& x) {
const auto one = ck_tile::type_convert<CompDataType>(1.0f);
if constexpr(std::is_same_v<CompDataType, float>)
{
x = x * __builtin_amdgcn_rcpf(one + __expf(-x));
}
else
{
x = x / (one + exp(-x));
}
};
const auto f_exp = [&](CompDataType x) {
if constexpr(std::is_same_v<CompDataType, float>)
{
return __expf(x);
}
else
{
return exp(x);
}
};
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kK1>{}),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem>());
auto null_randval_window = [&]() {
if constexpr(kHasDropout)
{
const auto null_randval_dram = [&]() {
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<uint8_t*>(nullptr),
make_tuple(1, 1),
make_tuple(1, 1),
number<1>{},
number<1>{});
return pad_tensor_view(null_dram_naive,
make_tuple(number<1>{}, number<1>{}),
sequence<true, true>{});
}();
return make_tile_window(
null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
}
else
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
using q_reg_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
Policy::template MakeQRegSingleRepMTileDistribution<Problem>()));
statically_indexed_array<q_reg_tile_type, kGemmNumRepM> q_reg_tiles;
using q_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
Policy::template MakeQRegTileDistribution<Problem>()));
q_tile_type q_tile;
{
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
store_tile(q_lds_write_window, q_dram_tiles[i_rep]);
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
// by each wavefront is read by itself
__builtin_amdgcn_s_waitcnt(0xc07f);
q_reg_tiles[i_rep] = load_tile(q_lds_read_window);
__builtin_amdgcn_s_waitcnt(0xc07f);
// the following codes will not generate actual instructions by the compiler
set_slice_tile(q_tile,
q_reg_tiles[i_rep],
sequence<i_rep * kGemmSingleRepM, 0>{},
sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{});
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice read
// by each wavefront is over-written by itself
});
clear_tile(o_acc);
if constexpr(kUseSoftmax)
{
set_tile(m, -numeric<CompDataType>::infinity());
clear_tile(l);
};
};
q_tile = tile_elementwise_in(q_element_func, q_tile);
auto seqlen_k_curr = seqlen_k_start;
__builtin_amdgcn_sched_barrier(0x00000001);
// ensure all q_reg_tiles[] have been loaded from LDS, so the LDS can be reused by k_tile
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0x00000001);
using v_tile_type = decltype(load_tile(v_dram_window));
statically_indexed_array<v_tile_type, k1_loops> v_tiles;
do
{
// STAGE 1, Gemm_0 ( S = Q@K )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(k_lds_write_windows[i_k1],
tile_elementwise_in(k_element_func, k_tiles[i_k1]));
__builtin_amdgcn_sched_barrier(0x00000001);
// load v_tiles used in current iteration
v_tiles[i_k1] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {kK1, 0});
__builtin_amdgcn_sched_barrier(0x00000001);
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_k1 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
set_slice_tile(pcomp_tile,
tmp_tile,
sequence<0, i_k1 * kK1>{},
sequence<kM0, (i_k1 + 1) * kK1>{});
});
__builtin_amdgcn_sched_barrier(0x00000001);
// STAGE 2, scale_s, add bias, mask, siLU
if constexpr(kHasBias)
{
const auto bias_tile = load_tile(bias_dram_window);
tile_elementwise_inout(
[&scale_s, &bias_element_func](auto& x, const auto& y) {
x = x * scale_s + type_convert<CompDataType>(bias_element_func(y));
},
pcomp_tile,
bias_tile);
move_tile_window(bias_dram_window, {0, kN0});
}
else
{
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile);
}
if constexpr(!kUseSoftmax)
{
if(!mask.IsFullTileInsideMask(
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
{
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if(!mask.IsTokenPairInsideMask(row, col))
{
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
};
});
});
}
tile_elementwise_inout(f_silu, pcomp_tile);
tile_elementwise_inout(
[&](auto& x) { x = x * type_convert<CompDataType>(scale_p); }, pcomp_tile);
}
else
{
if(!mask.IsFullTileInsideMask(
q_origin.at(number<0>{}), seqlen_k_curr, number<kN0>{}, number<kM0>{}))
{
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if(!mask.IsTokenPairInsideMask(row, col) || col >= seqlen_k_end)
{
pcomp_tile(i_j_idx) = -numeric<CompDataType>::infinity();
};
});
});
}
else
{
constexpr auto p_spans = PcompBlockTileType::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto col = seqlen_k_curr + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if(col >= seqlen_k_end)
{
pcomp_tile(i_j_idx) = -numeric<CompDataType>::infinity();
};
});
});
};
auto m_local = block_tile_reduce<CompDataType>(
pcomp_tile, sequence<1>{}, f_max, -numeric<CompDataType>::infinity());
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m;
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local);
constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(m[i_idx] == -numeric<CompDataType>::infinity())
{
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
pcomp_tile(i_j_idx) = type_convert<CompDataType>(0.0f);
});
}
else
{
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]);
});
}
});
auto rowsum_p = block_tile_reduce<CompDataType>(
pcomp_tile, sequence<1>{}, f_sum, CompDataType{0});
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// adjust o_acc[] according to the update between m and m_old
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(m[i_idx] == -numeric<CompDataType>::infinity())
{
l(i_idx) = rowsum_p[i_idx];
}
else
{
const auto tmp = f_exp(m_old[i_idx] - m[i_idx]);
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
}
});
};
seqlen_k_curr += kN0;
if constexpr(kHasDropout)
{
auto randval_lds_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
}
auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, pcomp_tile));
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
// this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4
if constexpr((k1_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers)
{
__builtin_amdgcn_s_barrier();
};
// STAGE 3, Gemm_1 ( O = P@V )
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tiles[number<i_k1>{}]));
__builtin_amdgcn_sched_barrier(0x00000001);
// load k_tiles used by next iteration
k_tiles[i_k1] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {kK1, 0});
__builtin_amdgcn_sched_barrier(0x00000001);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0x00000001);
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_windows[number<i_k1 + 2>{}]);
});
} while(seqlen_k_curr < seqlen_k_end);
if constexpr(kUseSoftmax)
{
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
if(m[i_idx] == -numeric<CompDataType>::infinity())
o_acc(i_j_idx) = 0.0f;
else
o_acc(i_j_idx) *= 1.0f / l[i_idx];
});
});
};
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
/// Convenience overload: run the full attention pipeline with identity
/// element-wise functors for every per-stage hook (Q, K, V, bias, S-acc,
/// P-compute and O-acc), i.e. no extra per-element transformation is applied.
template <typename QDramBlockWindowTmp,
          typename KDramBlockWindowTmp,
          typename VDramBlockWindowTmp,
          typename BiasDramBlockWindowTmp,
          typename HstuMask>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
           const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
           const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
           const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
           HstuMask mask,
           float scale_s, // scaling value exerted on the immediate Q@K result
           float scale_p, // scaling value exerted on the SiLU result
           void* smem_ptr,
           DropoutType& dropout) const
{
    // single pass-through functor reused for all element-wise hooks
    const identity no_op{};
    return operator()(q_dram_block_window_tmp,
                      no_op, // q_element_func
                      k_dram_block_window_tmp,
                      no_op, // k_element_func
                      v_dram_block_window_tmp,
                      no_op, // v_element_func
                      bias_dram_block_window_tmp,
                      no_op, // bias_element_func
                      no_op, // s_acc_element_func
                      no_op, // p_compute_element_func
                      no_op, // o_acc_element_func
                      mask,
                      scale_s,
                      scale_p,
                      smem_ptr,
                      dropout);
}
};
} // namespace ck_tile

View File

@@ -16,6 +16,7 @@
#include "hstu_attention_pipeline_problem.hpp"
#include "hstu_attention_traits.hpp"
#include "hstu_attention_fwd_pipeline.hpp"
#include "hstu_attention_fwd_trload_pipeline.hpp"
#include "hstu_attention_fwd_kernel.hpp"
#include "hstu_attention_epilogue.hpp"
@@ -32,6 +33,12 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
HstuAttentionWithSoftmaxFwdTileSetting<MaxK>,
HstuAttentionNoSoftmaxFwdTileSetting<MaxK>>::Type;
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
static constexpr bool kUseTrLoad = true;
#else
static constexpr bool kUseTrLoad = false;
#endif
template <typename HstuTraits>
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
InOutDataType,
@@ -43,6 +50,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
kHasDropout,
kUseCausal,
kUseSoftmax,
kUseTrLoad,
HstuAttentionTileSetting,
HstuTraits>;
@@ -74,8 +82,11 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
kPadSeqLenQ,
kPadHeadDimV>>;
using HstuPipeline = ck_tile::HstuAttentionFwdPipelineQRKSVS<HstuPipelineProblem>;
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
using HstuPipeline = std::conditional_t<
kUseTrLoad,
ck_tile::HstuAttentionFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>,
ck_tile::HstuAttentionFwdPipelineQRKSVS<HstuPipelineProblem>>;
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
RunWithKernel<HstuKernel>(param, stream);
});

View File

@@ -22,6 +22,7 @@ template <typename InOutDataType_,
bool kHasDropout_,
bool kHasCausal_,
bool kUseSoftmax_,
bool kUseTrLoad_, // use transposed loading to load V tile from lds to vgprs
typename AttentionTileSetting_,
typename Traits_>
struct HstuAttentionFwdPipelineProblem
@@ -44,6 +45,7 @@ struct HstuAttentionFwdPipelineProblem
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kHasCausal = kHasCausal_;
static constexpr bool kUseSoftmax = kUseSoftmax_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
using HstuAttentionTileSetting = remove_cvref_t<AttentionTileSetting_>;