Fix identified issues for implementing kHasDropout == true

This commit is contained in:
Qianfeng Zhang
2026-06-25 16:32:48 +00:00
parent c3df01a519
commit 4250341a92
8 changed files with 339 additions and 168 deletions

View File

@@ -579,15 +579,22 @@ struct HstuAttentionFwdKernel
ck_tile::index_t num_tile_in_seqlen =
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0);
if constexpr(kUseGroup)
// when kHasDropout is true, we should not give special consideration to minfull_attn_seqlen
// > 0, since BlockDropout requires each workgroup to have a kM0-aligned i_m0 position
if constexpr(!kHasDropout)
{
num_tile_in_seqlen += 1;
}
else
{
if(has_minfull_attn_seqlen)
if constexpr(kUseGroup)
{
// always assume minfull_attn_seqlen > 0 since we don't know its exact value
// for each group
num_tile_in_seqlen += 1;
};
}
else
{
if(has_minfull_attn_seqlen)
num_tile_in_seqlen += 1;
};
}
if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kQKHeaddim)
{
@@ -763,35 +770,44 @@ struct HstuAttentionFwdKernel
bool is_tile_in_first_split = true;
index_t i_m0;
if(kargs.min_full_attn_seqlen > 0)
// when kHasDropout is true, we should not give special consideration to minfull_attn_seqlen
// > 0, since BlockDropout requires each workgroup to have a kM0-aligned i_m0 position
if constexpr(!kHasDropout)
{
// need consider for cases where min_full_attn_seqlen be bigger than max_uih_len
if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen)
if(kargs.min_full_attn_seqlen > 0)
{
seqlen_in_first_split = kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen;
// need consider for cases where min_full_attn_seqlen be bigger than max_uih_len
if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen)
{
seqlen_in_first_split =
kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen;
index_t num_tile_in_first_split =
__builtin_amdgcn_readfirstlane(ck_tile::integer_divide_ceil(
seqlen_in_first_split, HstuAttentionPipeline::kM0));
index_t num_tile_in_first_split =
__builtin_amdgcn_readfirstlane(ck_tile::integer_divide_ceil(
seqlen_in_first_split, HstuAttentionPipeline::kM0));
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
i_m0 = is_tile_in_first_split
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
HstuAttentionPipeline::kM0) +
seqlen_in_first_split;
i_m0 =
is_tile_in_first_split
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
HstuAttentionPipeline::kM0) +
seqlen_in_first_split;
}
else
{
seqlen_in_first_split = 0;
is_tile_in_first_split = false;
// adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor
kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target;
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
};
}
else
{
seqlen_in_first_split = 0;
is_tile_in_first_split = false;
// adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor
kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target;
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
};
}
else
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
@@ -953,6 +969,36 @@ struct HstuAttentionFwdKernel
}
}();
// Builds the placeholder random-value window that is handed to the attention pipeline.
// With dropout enabled it is a real tile window of shape [kM0, kN0] positioned at row i_m0,
// column 0; without dropout it is a 1x1 null tile window.
auto null_randval_window = [&]() {
if constexpr(kHasDropout)
{
// A real tile window has to be created from null_randval_dram because the null tile
// window has no store_tile() overload, which would cause a compile error when it is
// used inside BlockDropout::Run(). This dram window is also needed to provide the
// start_m0_idx used in BlockDropout::Run().
const auto null_randval_dram = [&]() {
// The view is created over a null uint8_t pointer with extents
// (seqlen_q_in_ctrl, seqlen_kv) and strides (seqlen_kv, 1).
// NOTE(review): presumably this memory is never actually accessed -- confirm that
// BlockDropout::Run() does not store random values through this window.
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<uint8_t*>(nullptr),
make_tuple(seqlen_q_in_ctrl, kargs.seqlen_kv),
make_tuple(kargs.seqlen_kv, 1),
number<1>{},
number<1>{});
// Padding is requested for both dimensions, using the (kM0, kN0) tile lengths.
return pad_tensor_view(null_dram_naive,
make_tuple(number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kN0>{}),
sequence<true, true>{});
}();
// The window origin carries this workgroup's i_m0 as its row coordinate.
return make_tile_window(null_randval_dram,
make_tuple(number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kN0>{}),
{i_m0, 0});
}
else
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
if constexpr(kHasDropout)
{
@@ -1014,6 +1060,7 @@ struct HstuAttentionFwdKernel
k_dram_window,
v_dram_window,
bias_dram_window,
null_randval_window,
seqlen_k_start,
seqlen_k_end,
mask,
@@ -1029,6 +1076,7 @@ struct HstuAttentionFwdKernel
v_dram_window,
bias_dram_window,
lse_dram_window,
null_randval_window,
seqlen_k_start,
seqlen_k_end,
mask,
@@ -1067,6 +1115,7 @@ struct HstuAttentionFwdKernel
k_dram_window,
v_dram_window,
bias_dram_window,
null_randval_window,
seqlen_k_start,
seqlen_k_end,
mask,
@@ -1082,6 +1131,7 @@ struct HstuAttentionFwdKernel
v_dram_window,
bias_dram_window,
lse_dram_window,
null_randval_window,
seqlen_k_start,
seqlen_k_end,
mask,

View File

@@ -592,6 +592,75 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
constexpr index_t WarpGemmK =
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{});
#ifdef __gfx950__
static_assert((WarpGemmM == 16 && WarpGemmK == 32) ||
(WarpGemmM == 32 && WarpGemmK == 16),
"Not supported WarpGemm sizes!");
#else
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)) ||
(WarpGemmM == 32 && (WarpGemmK == 8 || WarpGemmK == 16)),
"Not supported WarpGemm sizes!");
#endif
return WarpGemmDispatcher<
typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{}),
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<1>{}),
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{}),
true,
false,
false,
WGAttrNumAccessEnum::Single>{};
}
else
{
static_assert(false, "Not supported data types!");
}
}();
using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps,
decltype(warp_gemm)>;
if constexpr(1 < Problem::kNumGemm0Warps)
return BlockGemmARegBSmemCRegV2Hack_0<GemmProblem, BlockGemmPolicy>{};
else
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
// Same as GetQKBlockGemm but with kN0 (instead of kN0Sub) as the N tile dimension.
// This is used as the BlockGemm template argument to BlockDropout::Run() so that
// kNPerBlock = kN0, ensuring dropout is applied to the full pcomp_tile [kM0, kN0]
// rather than only the first kN0Sub columns.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKCombinedBlockGemm()
{
using GemmProblem = BlockGemmProblem<
typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
Problem::kNumGemm0Warps * get_warp_size(),
TileGemmShape<sequence<Problem::HstuAttentionTileSetting::kM0,
Problem::HstuAttentionTileSetting::kN0,
Problem::HstuAttentionTileSetting::kQKHeaddim>,
typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps,
typename Problem::HstuAttentionTileSetting::Gemm0WarpTile>>;
auto warp_gemm = [&]() {
if constexpr((std::is_same_v<typename Problem::QKVDataType, half_t> ||
std::is_same_v<typename Problem::QKVDataType, bf16_t>) &&
std::is_same_v<typename Problem::GemmAccDataType, float>)
{
constexpr index_t WarpGemmM =
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{});
constexpr index_t WarpGemmK =
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{});
#ifdef __gfx950__
static_assert((WarpGemmM == 16 && WarpGemmK == 32) ||
(WarpGemmM == 32 && WarpGemmK == 16),

View File

@@ -540,15 +540,22 @@ struct HstuAttentionFwdSplitKVKernel
ck_tile::index_t num_tile_in_seqlen =
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0);
if constexpr(kUseGroup)
// when kHasDropout is true, we should not give special consideration to minfull_attn_seqlen
// > 0, since BlockDropout requires each workgroup to have a kM0-aligned i_m0 position
if constexpr(!kHasDropout)
{
num_tile_in_seqlen += 1;
}
else
{
if(has_minfull_attn_seqlen)
if constexpr(kUseGroup)
{
// always assume minfull_attn_seqlen > 0 since we don't know its exact value
// for each group
num_tile_in_seqlen += 1;
};
}
else
{
if(has_minfull_attn_seqlen)
num_tile_in_seqlen += 1;
};
}
if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kQKHeaddim)
{
@@ -760,35 +767,44 @@ struct HstuAttentionFwdSplitKVKernel
bool is_tile_in_first_split = true;
index_t i_m0;
if(kargs.min_full_attn_seqlen > 0)
// when kHasDropout is true, we should not give special consideration to minfull_attn_seqlen
// > 0, since BlockDropout requires each workgroup to have a kM0-aligned i_m0 position
if constexpr(!kHasDropout)
{
// need consider for cases where min_full_attn_seqlen be bigger than max_uih_len
if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen)
if(kargs.min_full_attn_seqlen > 0)
{
seqlen_in_first_split = kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen;
// need consider for cases where min_full_attn_seqlen be bigger than max_uih_len
if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen)
{
seqlen_in_first_split =
kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen;
index_t num_tile_in_first_split =
__builtin_amdgcn_readfirstlane(ck_tile::integer_divide_ceil(
seqlen_in_first_split, HstuAttentionPipeline::kM0));
index_t num_tile_in_first_split =
__builtin_amdgcn_readfirstlane(ck_tile::integer_divide_ceil(
seqlen_in_first_split, HstuAttentionPipeline::kM0));
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
i_m0 = is_tile_in_first_split
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
HstuAttentionPipeline::kM0) +
seqlen_in_first_split;
i_m0 =
is_tile_in_first_split
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
HstuAttentionPipeline::kM0) +
seqlen_in_first_split;
}
else
{
seqlen_in_first_split = 0;
is_tile_in_first_split = false;
// adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor
kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target;
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
};
}
else
{
seqlen_in_first_split = 0;
is_tile_in_first_split = false;
// adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor
kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target;
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
};
}
else
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
@@ -956,6 +972,36 @@ struct HstuAttentionFwdSplitKVKernel
}
}();
// Builds the placeholder random-value window that is handed to the attention pipeline.
// With dropout enabled it is a real tile window of shape [kM0, kN0] positioned at row i_m0,
// column 0; without dropout it is a 1x1 null tile window.
auto null_randval_window = [&]() {
if constexpr(kHasDropout)
{
// A real tile window has to be created from null_randval_dram because the null tile
// window has no store_tile() overload, which would cause a compile error when it is
// used inside BlockDropout::Run(). This dram window is also needed to provide the
// start_m0_idx used in BlockDropout::Run().
const auto null_randval_dram = [&]() {
// The view is created over a null uint8_t pointer with extents
// (seqlen_q_in_ctrl, seqlen_kv) and strides (seqlen_kv, 1).
// NOTE(review): presumably this memory is never actually accessed -- confirm that
// BlockDropout::Run() does not store random values through this window.
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<uint8_t*>(nullptr),
make_tuple(seqlen_q_in_ctrl, kargs.seqlen_kv),
make_tuple(kargs.seqlen_kv, 1),
number<1>{},
number<1>{});
// Padding is requested for both dimensions, using the (kM0, kN0) tile lengths.
return pad_tensor_view(null_dram_naive,
make_tuple(number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kN0>{}),
sequence<true, true>{});
}();
// The window origin carries this workgroup's i_m0 as its row coordinate.
return make_tile_window(null_randval_dram,
make_tuple(number<HstuAttentionPipeline::kM0>{},
number<HstuAttentionPipeline::kN0>{}),
{i_m0, 0});
}
else
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
if constexpr(kHasDropout)
{
@@ -1020,6 +1066,7 @@ struct HstuAttentionFwdSplitKVKernel
k_dram_window,
v_dram_window,
bias_dram_window,
null_randval_window,
seqlen_k_start,
seqlen_k_end,
mask,
@@ -1035,6 +1082,7 @@ struct HstuAttentionFwdSplitKVKernel
v_dram_window,
bias_dram_window,
lse_acc_dram_window,
null_randval_window,
seqlen_k_start,
seqlen_k_end,
mask,
@@ -1076,6 +1124,7 @@ struct HstuAttentionFwdSplitKVKernel
k_dram_window,
v_dram_window,
bias_dram_window,
null_randval_window,
seqlen_k_start,
seqlen_k_end,
mask,
@@ -1091,6 +1140,7 @@ struct HstuAttentionFwdSplitKVKernel
v_dram_window,
bias_dram_window,
lse_acc_dram_window,
null_randval_window,
seqlen_k_start,
seqlen_k_end,
mask,

View File

@@ -114,12 +114,14 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename NullRandValDramWindowTmp,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
NullRandValDramWindowTmp& null_randval_window_tmp, // M0*N0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask& mask,
@@ -155,6 +157,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPVTBlockGemm<Problem>();
using Gemm0Combined = decltype(Policy::template GetQKCombinedBlockGemm<Problem>());
// SaccBlockTile size is [kM0, kN0Sub]
// PcompBlockTile size is [kM0, kN0]
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
@@ -265,31 +269,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem>());
auto null_randval_window = [&]() {
if constexpr(kHasDropout)
{
// need to make a tile window from this null_randval_dram since the null_tile_window
// does not have store_tile() over-loaded, will cause compiling issue when used
// inside BlockDropout::Run()
const auto null_randval_dram = [&]() {
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<uint8_t*>(nullptr),
make_tuple(1, 1),
make_tuple(1, 1),
number<1>{},
number<1>{});
return pad_tensor_view(null_dram_naive,
make_tuple(number<1>{}, number<1>{}),
sequence<true, true>{});
}();
return make_tile_window(
null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
}
else
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
auto null_randval_window = dropout.template MakeRandvalDramWindow<Gemm0Combined>(
null_randval_window_tmp, seqlen_k_start);
clear_tile(o_acc);
@@ -396,17 +377,21 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
tile_elementwise_inout([&](auto& x) { x = x * type_convert<CompDataType>(scale_p); },
pcomp_tile);
seqlen_k_curr += kN0;
if constexpr(kHasDropout)
{
__builtin_amdgcn_sched_barrier(0);
auto randval_lds_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
dropout.template Run<Gemm0Combined, CompDataType, uint8_t>(
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
__builtin_amdgcn_sched_barrier(0);
}
seqlen_k_curr += kN0;
auto p = cast_tile<PDataType>(pcomp_tile);
// STAGE 3, Gemm_1 ( O = P@V )

View File

@@ -115,12 +115,14 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename NullRandValDramWindowTmp,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
NullRandValDramWindowTmp& null_randval_window_tmp, // M0*N0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask& mask,
@@ -156,6 +158,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPVTBlockGemm<Problem, true /*kUseTrLoad*/>();
using Gemm0Combined = decltype(Policy::template GetQKCombinedBlockGemm<Problem>());
// SaccBlockTile size is [kM0, kN0Sub]
// PcompBlockTile size is [kM0, kN0]
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
@@ -272,31 +276,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem>());
auto null_randval_window = [&]() {
if constexpr(kHasDropout)
{
// need to make a tile window from this null_randval_dram since the null_tile_window
// does not have store_tile() over-loaded, will cause compiling issue when used
// inside BlockDropout::Run()
const auto null_randval_dram = [&]() {
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<uint8_t*>(nullptr),
make_tuple(1, 1),
make_tuple(1, 1),
number<1>{},
number<1>{});
return pad_tensor_view(null_dram_naive,
make_tuple(number<1>{}, number<1>{}),
sequence<true, true>{});
}();
return make_tile_window(
null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
}
else
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
auto null_randval_window = dropout.template MakeRandvalDramWindow<Gemm0Combined>(
null_randval_window_tmp, seqlen_k_start);
clear_tile(o_acc);
@@ -383,18 +364,22 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
detail::scale_tile_in_pack(pcomp_tile, scale_p);
seqlen_k_curr += kN0;
if constexpr(kHasDropout)
{
__builtin_amdgcn_sched_barrier(0);
auto randval_lds_ptr =
reinterpret_cast<char*>(smem_ptr) +
Policy::template GetSmemSizeKV<Problem, true /*kPipelineUseTrLoad*/>();
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
dropout.template Run<Gemm0Combined, CompDataType, uint8_t>(
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
__builtin_amdgcn_sched_barrier(0);
}
seqlen_k_curr += kN0;
auto p = cast_tile<PDataType>(pcomp_tile);
// check whether first V-LdsBufer overlap with last K-LdsBuffer,

View File

@@ -119,6 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindow,
typename NullRandValDramWindowTmp,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
@@ -126,6 +127,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
NullRandValDramWindowTmp& null_randval_window_tmp, // M0*N0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask& mask,
@@ -164,6 +166,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPVTBlockGemm<Problem>();
using Gemm0Combined = decltype(Policy::template GetQKCombinedBlockGemm<Problem>());
// SaccBlockTile size is [kM0, kN0Sub]
// PcompBlockTile size is [kM0, kN0]
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
@@ -296,31 +300,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem>());
auto null_randval_window = [&]() {
if constexpr(kHasDropout)
{
// need to make a tile window from this null_randval_dram since the null_tile_window
// does not have store_tile() over-loaded, will cause compiling issue when used
// inside BlockDropout::Run()
const auto null_randval_dram = [&]() {
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<uint8_t*>(nullptr),
make_tuple(1, 1),
make_tuple(1, 1),
number<1>{},
number<1>{});
return pad_tensor_view(null_dram_naive,
make_tuple(number<1>{}, number<1>{}),
sequence<true, true>{});
}();
return make_tile_window(
null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
}
else
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
auto null_randval_window = dropout.template MakeRandvalDramWindow<Gemm0Combined>(
null_randval_window_tmp, seqlen_k_start);
clear_tile(o_acc);
set_tile(m, -numeric<CompDataType>::infinity());
@@ -522,17 +503,21 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
l(i_idx) = l[i_idx] + rowsum_p[i_idx];
});
seqlen_k_curr += kN0;
if constexpr(kHasDropout)
{
__builtin_amdgcn_sched_barrier(0);
auto randval_lds_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
dropout.template Run<Gemm0Combined, CompDataType, uint8_t>(
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
__builtin_amdgcn_sched_barrier(0);
}
seqlen_k_curr += kN0;
auto p = cast_tile<PDataType>(pcomp_tile);
__builtin_amdgcn_sched_barrier(0x00000001);

View File

@@ -119,6 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEorLSEaccDramBlockWindow,
typename NullRandValDramWindowTmp,
typename HstuMask>
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
@@ -126,6 +127,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
NullRandValDramWindowTmp& null_randval_window_tmp, // M0*N0 tile
index_t seqlen_k_start,
index_t seqlen_k_end,
HstuMask& mask,
@@ -165,6 +167,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
constexpr auto gemm_1 =
Policy::template GetPVTBlockGemm<Problem, true /*kPipelineUseTrLoad*/>();
using Gemm0Combined = decltype(Policy::template GetQKCombinedBlockGemm<Problem>());
// SaccBlockTile size is [kM0, kN0Sub]
// PcompBlockTile size is [kM0, kN0]
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
@@ -303,31 +307,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem>());
auto null_randval_window = [&]() {
if constexpr(kHasDropout)
{
// need to make a tile window from this null_randval_dram since the null_tile_window
// does not have store_tile() over-loaded, will cause compiling issue when used
// inside BlockDropout::Run()
const auto null_randval_dram = [&]() {
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<uint8_t*>(nullptr),
make_tuple(1, 1),
make_tuple(1, 1),
number<1>{},
number<1>{});
return pad_tensor_view(null_dram_naive,
make_tuple(number<1>{}, number<1>{}),
sequence<true, true>{});
}();
return make_tile_window(
null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
}
else
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
auto null_randval_window = dropout.template MakeRandvalDramWindow<Gemm0Combined>(
null_randval_window_tmp, seqlen_k_start);
clear_tile(o_acc);
set_tile(m, -numeric<CompDataType>::infinity());
@@ -528,18 +509,22 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
l(i_idx) = l[i_idx] + rowsum_p[i_idx];
});
seqlen_k_curr += kN0;
if constexpr(kHasDropout)
{
__builtin_amdgcn_sched_barrier(0);
auto randval_lds_ptr =
reinterpret_cast<char*>(smem_ptr) +
Policy::template GetSmemSizeKV<Problem, true /*kPipelineUseTrLoad*/>();
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
dropout.template Run<Gemm0Combined, CompDataType, uint8_t>(
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
__builtin_amdgcn_sched_barrier(0);
}
seqlen_k_curr += kN0;
auto p = cast_tile<PDataType>(pcomp_tile);
__builtin_amdgcn_sched_barrier(0x00000001);

View File

@@ -0,0 +1,62 @@
#!/bin/bash
# Runs the tile_example_hstu_attention example over a fixed list of masking configurations,
# once per data type (fp16, bf16), with -p_drop=0.2 appended to every invocation.
#
# Usage: <this script> [attn_scale] [norm_dist]
#   $1  value forwarded as -attn_scale (defaults to 0 when omitted)
#   $2  value forwarded as -norm_dist  (defaults to 0 when omitted)
#
# BUILD is a relative path, so the script must be started from the directory that
# contains the build tree.
# NOTE(review): the exit status of the individual runs is not checked; a failing run does
# not stop the script -- inspect the traced output to spot failures.
BUILD=build
# EXE holds both the binary path and the dropout flag; it is expanded unquoted below so
# that the shell splits it into two words.
EXE="$BUILD/bin/tile_example_hstu_attention -p_drop=0.2"
# Optional first argument: attention scale passed to every run.
attn_scale=0
if [ $# -ge 1 ]; then
attn_scale=$1
fi
# Optional second argument: value passed as -norm_dist to every run.
ndist=0
if [ $# -ge 2 ]; then
ndist=$2
fi
for dtype in "fp16" "bf16"; do
# Trace each command so the exact arguments of every run appear in the output.
set -x
## no masking batched
$EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## no masking jagged
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## batched causal
$EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged causal
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## batched causal+local
$EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged causal+local
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## batched causal+local+context (with minfull_len=7)
$EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged causal+local+context (with minfull_len=7)
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
## batched causal+local+context+target (with minfull_len=7)
$EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged causal+local+context+target (with minfull_len=7)
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged no-causal+local+context+target (with minfull_len=7)
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged causal+local+target (minfull_len > max_uih_len)
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged causal+local+context+target (minfull_len > max_uih_len)
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
## jagged no-causal+local+context+target (minfull_len > max_uih_len)
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
set +x
done