mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Fix identified issues for implementing kHasDropout == true
This commit is contained in:
@@ -579,15 +579,22 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t num_tile_in_seqlen =
|
||||
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0);
|
||||
|
||||
if constexpr(kUseGroup)
|
||||
// when kHasDropout is true, we should not give special consideration to minfull_attn_seqlen
|
||||
// > 0, since BlockDropout requires each workgroup has a kM0 aligned i_m0 position
|
||||
if constexpr(!kHasDropout)
|
||||
{
|
||||
num_tile_in_seqlen += 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(has_minfull_attn_seqlen)
|
||||
if constexpr(kUseGroup)
|
||||
{
|
||||
// always assume minfull_attn_seqlen > 0 since we don't know its exact value
|
||||
// for each group
|
||||
num_tile_in_seqlen += 1;
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
if(has_minfull_attn_seqlen)
|
||||
num_tile_in_seqlen += 1;
|
||||
};
|
||||
}
|
||||
|
||||
if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kQKHeaddim)
|
||||
{
|
||||
@@ -763,35 +770,44 @@ struct HstuAttentionFwdKernel
|
||||
bool is_tile_in_first_split = true;
|
||||
index_t i_m0;
|
||||
|
||||
if(kargs.min_full_attn_seqlen > 0)
|
||||
// when kHasDropout is true, we should not give special consideration to minfull_attn_seqlen
|
||||
// > 0, since BlockDropout requires each workgroup has a kM0 aligned i_m0 position
|
||||
if constexpr(!kHasDropout)
|
||||
{
|
||||
// need consider for cases where min_full_attn_seqlen be bigger than max_uih_len
|
||||
if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen)
|
||||
if(kargs.min_full_attn_seqlen > 0)
|
||||
{
|
||||
seqlen_in_first_split = kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen;
|
||||
// need consider for cases where min_full_attn_seqlen be bigger than max_uih_len
|
||||
if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen)
|
||||
{
|
||||
seqlen_in_first_split =
|
||||
kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen;
|
||||
|
||||
index_t num_tile_in_first_split =
|
||||
__builtin_amdgcn_readfirstlane(ck_tile::integer_divide_ceil(
|
||||
seqlen_in_first_split, HstuAttentionPipeline::kM0));
|
||||
index_t num_tile_in_first_split =
|
||||
__builtin_amdgcn_readfirstlane(ck_tile::integer_divide_ceil(
|
||||
seqlen_in_first_split, HstuAttentionPipeline::kM0));
|
||||
|
||||
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
|
||||
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
|
||||
|
||||
i_m0 = is_tile_in_first_split
|
||||
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
|
||||
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
|
||||
HstuAttentionPipeline::kM0) +
|
||||
seqlen_in_first_split;
|
||||
i_m0 =
|
||||
is_tile_in_first_split
|
||||
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
|
||||
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
|
||||
HstuAttentionPipeline::kM0) +
|
||||
seqlen_in_first_split;
|
||||
}
|
||||
else
|
||||
{
|
||||
seqlen_in_first_split = 0;
|
||||
is_tile_in_first_split = false;
|
||||
|
||||
// adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor
|
||||
kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target;
|
||||
|
||||
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
seqlen_in_first_split = 0;
|
||||
is_tile_in_first_split = false;
|
||||
|
||||
// adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor
|
||||
kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target;
|
||||
|
||||
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
|
||||
};
|
||||
}
|
||||
else
|
||||
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
|
||||
@@ -953,6 +969,36 @@ struct HstuAttentionFwdKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto null_randval_window = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
// need to make a tile window from this null_randval_dram since the null_tile_window
|
||||
// does not have store_tile() over-loaded, which will cause compiling issue when
|
||||
// used inside BlockDropout::Run(). Also we need this dram window to provide
|
||||
// start_m0_idx used in BlockDropout::Run()
|
||||
const auto null_randval_dram = [&]() {
|
||||
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<uint8_t*>(nullptr),
|
||||
make_tuple(seqlen_q_in_ctrl, kargs.seqlen_kv),
|
||||
make_tuple(kargs.seqlen_kv, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(null_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kN0>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(null_randval_dram,
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kN0>{}),
|
||||
{i_m0, 0});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
|
||||
auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
@@ -1014,6 +1060,7 @@ struct HstuAttentionFwdKernel
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
null_randval_window,
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
@@ -1029,6 +1076,7 @@ struct HstuAttentionFwdKernel
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
lse_dram_window,
|
||||
null_randval_window,
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
@@ -1067,6 +1115,7 @@ struct HstuAttentionFwdKernel
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
null_randval_window,
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
@@ -1082,6 +1131,7 @@ struct HstuAttentionFwdKernel
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
lse_dram_window,
|
||||
null_randval_window,
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
|
||||
@@ -592,6 +592,75 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
constexpr index_t WarpGemmK =
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{});
|
||||
|
||||
#ifdef __gfx950__
|
||||
static_assert((WarpGemmM == 16 && WarpGemmK == 32) ||
|
||||
(WarpGemmM == 32 && WarpGemmK == 16),
|
||||
"Not supported WarpGemm sizes!");
|
||||
#else
|
||||
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)) ||
|
||||
(WarpGemmM == 32 && (WarpGemmK == 8 || WarpGemmK == 16)),
|
||||
"Not supported WarpGemm sizes!");
|
||||
#endif
|
||||
|
||||
return WarpGemmDispatcher<
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
WGAttrNumAccessEnum::Single>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Not supported data types!");
|
||||
}
|
||||
}();
|
||||
|
||||
using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy<
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps,
|
||||
decltype(warp_gemm)>;
|
||||
|
||||
if constexpr(1 < Problem::kNumGemm0Warps)
|
||||
return BlockGemmARegBSmemCRegV2Hack_0<GemmProblem, BlockGemmPolicy>{};
|
||||
else
|
||||
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
// Same as GetQKBlockGemm but with kN0 (instead of kN0Sub) as the N tile dimension.
|
||||
// This is used as the BlockGemm template argument to BlockDropout::Run() so that
|
||||
// kNPerBlock = kN0, ensuring dropout is applied to the full pcomp_tile [kM0, kN0]
|
||||
// rather than only the first kN0Sub columns.
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKCombinedBlockGemm()
|
||||
{
|
||||
using GemmProblem = BlockGemmProblem<
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::kNumGemm0Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::HstuAttentionTileSetting::kM0,
|
||||
Problem::HstuAttentionTileSetting::kN0,
|
||||
Problem::HstuAttentionTileSetting::kQKHeaddim>,
|
||||
typename Problem::HstuAttentionTileSetting::Gemm0BlockWarps,
|
||||
typename Problem::HstuAttentionTileSetting::Gemm0WarpTile>>;
|
||||
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr((std::is_same_v<typename Problem::QKVDataType, half_t> ||
|
||||
std::is_same_v<typename Problem::QKVDataType, bf16_t>) &&
|
||||
std::is_same_v<typename Problem::GemmAccDataType, float>)
|
||||
{
|
||||
constexpr index_t WarpGemmM =
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{});
|
||||
constexpr index_t WarpGemmK =
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{});
|
||||
|
||||
#ifdef __gfx950__
|
||||
static_assert((WarpGemmM == 16 && WarpGemmK == 32) ||
|
||||
(WarpGemmM == 32 && WarpGemmK == 16),
|
||||
|
||||
@@ -540,15 +540,22 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
ck_tile::index_t num_tile_in_seqlen =
|
||||
ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0);
|
||||
|
||||
if constexpr(kUseGroup)
|
||||
// when kHasDropout is true, we should not give special consideration to minfull_attn_seqlen
|
||||
// > 0, since BlockDropout requires each workgroup has a kM0 aligned i_m0 position
|
||||
if constexpr(!kHasDropout)
|
||||
{
|
||||
num_tile_in_seqlen += 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(has_minfull_attn_seqlen)
|
||||
if constexpr(kUseGroup)
|
||||
{
|
||||
// always assume minfull_attn_seqlen > 0 since we don't know its exact value
|
||||
// for each group
|
||||
num_tile_in_seqlen += 1;
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
if(has_minfull_attn_seqlen)
|
||||
num_tile_in_seqlen += 1;
|
||||
};
|
||||
}
|
||||
|
||||
if constexpr(HstuAttentionPipeline::kN1 < HstuAttentionPipeline::kQKHeaddim)
|
||||
{
|
||||
@@ -760,35 +767,44 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
bool is_tile_in_first_split = true;
|
||||
index_t i_m0;
|
||||
|
||||
if(kargs.min_full_attn_seqlen > 0)
|
||||
// when kHasDropout is true, we should not give special consideration to minfull_attn_seqlen
|
||||
// > 0, since BlockDropout requires each workgroup has a kM0 aligned i_m0 position
|
||||
if constexpr(!kHasDropout)
|
||||
{
|
||||
// need consider for cases where min_full_attn_seqlen be bigger than max_uih_len
|
||||
if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen)
|
||||
if(kargs.min_full_attn_seqlen > 0)
|
||||
{
|
||||
seqlen_in_first_split = kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen;
|
||||
// need consider for cases where min_full_attn_seqlen be bigger than max_uih_len
|
||||
if(kargs.seqlen_q - num_target > kargs.min_full_attn_seqlen)
|
||||
{
|
||||
seqlen_in_first_split =
|
||||
kargs.seqlen_q - num_target - kargs.min_full_attn_seqlen;
|
||||
|
||||
index_t num_tile_in_first_split =
|
||||
__builtin_amdgcn_readfirstlane(ck_tile::integer_divide_ceil(
|
||||
seqlen_in_first_split, HstuAttentionPipeline::kM0));
|
||||
index_t num_tile_in_first_split =
|
||||
__builtin_amdgcn_readfirstlane(ck_tile::integer_divide_ceil(
|
||||
seqlen_in_first_split, HstuAttentionPipeline::kM0));
|
||||
|
||||
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
|
||||
is_tile_in_first_split = (i_tile_m < num_tile_in_first_split);
|
||||
|
||||
i_m0 = is_tile_in_first_split
|
||||
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
|
||||
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
|
||||
HstuAttentionPipeline::kM0) +
|
||||
seqlen_in_first_split;
|
||||
i_m0 =
|
||||
is_tile_in_first_split
|
||||
? __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0)
|
||||
: __builtin_amdgcn_readfirstlane((i_tile_m - num_tile_in_first_split) *
|
||||
HstuAttentionPipeline::kM0) +
|
||||
seqlen_in_first_split;
|
||||
}
|
||||
else
|
||||
{
|
||||
seqlen_in_first_split = 0;
|
||||
is_tile_in_first_split = false;
|
||||
|
||||
// adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor
|
||||
kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target;
|
||||
|
||||
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
seqlen_in_first_split = 0;
|
||||
is_tile_in_first_split = false;
|
||||
|
||||
// adjust the min_full_attn_seqlen to be passed to HstuBlockMask constructor
|
||||
kargs.min_full_attn_seqlen = kargs.seqlen_q - num_target;
|
||||
|
||||
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
|
||||
};
|
||||
}
|
||||
else
|
||||
i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
|
||||
@@ -956,6 +972,36 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto null_randval_window = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
// need to make a tile window from this null_randval_dram since the null_tile_window
|
||||
// does not have store_tile() over-loaded, which will cause compiling issue when
|
||||
// used inside BlockDropout::Run(). Also we need this dram window to provide
|
||||
// start_m0_idx used in BlockDropout::Run()
|
||||
const auto null_randval_dram = [&]() {
|
||||
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<uint8_t*>(nullptr),
|
||||
make_tuple(seqlen_q_in_ctrl, kargs.seqlen_kv),
|
||||
make_tuple(kargs.seqlen_kv, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(null_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kN0>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(null_randval_dram,
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kN0>{}),
|
||||
{i_m0, 0});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
|
||||
auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
@@ -1020,6 +1066,7 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
null_randval_window,
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
@@ -1035,6 +1082,7 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
lse_acc_dram_window,
|
||||
null_randval_window,
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
@@ -1076,6 +1124,7 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
null_randval_window,
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
@@ -1091,6 +1140,7 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
lse_acc_dram_window,
|
||||
null_randval_window,
|
||||
seqlen_k_start,
|
||||
seqlen_k_end,
|
||||
mask,
|
||||
|
||||
@@ -114,12 +114,14 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename NullRandValDramWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
NullRandValDramWindowTmp& null_randval_window_tmp, // M0*N0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask& mask,
|
||||
@@ -155,6 +157,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPVTBlockGemm<Problem>();
|
||||
|
||||
using Gemm0Combined = decltype(Policy::template GetQKCombinedBlockGemm<Problem>());
|
||||
|
||||
// SaccBlockTile size is [kM0, kN0Sub]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
|
||||
@@ -265,31 +269,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<Problem>());
|
||||
|
||||
auto null_randval_window = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
// need to make a tile window from this null_randval_dram since the null_tile_window
|
||||
// does not have store_tile() over-loaded, will cause compiling issue when used
|
||||
// inside BlockDropout::Run()
|
||||
const auto null_randval_dram = [&]() {
|
||||
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<uint8_t*>(nullptr),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(null_dram_naive,
|
||||
make_tuple(number<1>{}, number<1>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
auto null_randval_window = dropout.template MakeRandvalDramWindow<Gemm0Combined>(
|
||||
null_randval_window_tmp, seqlen_k_start);
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
@@ -396,17 +377,21 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
tile_elementwise_inout([&](auto& x) { x = x * type_convert<CompDataType>(scale_p); },
|
||||
pcomp_tile);
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
dropout.template Run<Gemm0Combined, CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
auto p = cast_tile<PDataType>(pcomp_tile);
|
||||
|
||||
// STAGE 3, Gemm_1 ( O = P@V )
|
||||
|
||||
@@ -115,12 +115,14 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename NullRandValDramWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
NullRandValDramWindowTmp& null_randval_window_tmp, // M0*N0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask& mask,
|
||||
@@ -156,6 +158,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPVTBlockGemm<Problem, true /*kUseTrLoad*/>();
|
||||
|
||||
using Gemm0Combined = decltype(Policy::template GetQKCombinedBlockGemm<Problem>());
|
||||
|
||||
// SaccBlockTile size is [kM0, kN0Sub]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
|
||||
@@ -272,31 +276,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<Problem>());
|
||||
|
||||
auto null_randval_window = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
// need to make a tile window from this null_randval_dram since the null_tile_window
|
||||
// does not have store_tile() over-loaded, will cause compiling issue when used
|
||||
// inside BlockDropout::Run()
|
||||
const auto null_randval_dram = [&]() {
|
||||
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<uint8_t*>(nullptr),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(null_dram_naive,
|
||||
make_tuple(number<1>{}, number<1>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
auto null_randval_window = dropout.template MakeRandvalDramWindow<Gemm0Combined>(
|
||||
null_randval_window_tmp, seqlen_k_start);
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
@@ -383,18 +364,22 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
detail::scale_tile_in_pack(pcomp_tile, scale_p);
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) +
|
||||
Policy::template GetSmemSizeKV<Problem, true /*kPipelineUseTrLoad*/>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
dropout.template Run<Gemm0Combined, CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
auto p = cast_tile<PDataType>(pcomp_tile);
|
||||
|
||||
// check whether first V-LdsBufer overlap with last K-LdsBuffer,
|
||||
|
||||
@@ -119,6 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindow,
|
||||
typename NullRandValDramWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
@@ -126,6 +127,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
NullRandValDramWindowTmp& null_randval_window_tmp, // M0*N0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask& mask,
|
||||
@@ -164,6 +166,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPVTBlockGemm<Problem>();
|
||||
|
||||
using Gemm0Combined = decltype(Policy::template GetQKCombinedBlockGemm<Problem>());
|
||||
|
||||
// SaccBlockTile size is [kM0, kN0Sub]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
|
||||
@@ -296,31 +300,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<Problem>());
|
||||
|
||||
auto null_randval_window = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
// need to make a tile window from this null_randval_dram since the null_tile_window
|
||||
// does not have store_tile() over-loaded, will cause compiling issue when used
|
||||
// inside BlockDropout::Run()
|
||||
const auto null_randval_dram = [&]() {
|
||||
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<uint8_t*>(nullptr),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(null_dram_naive,
|
||||
make_tuple(number<1>{}, number<1>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
auto null_randval_window = dropout.template MakeRandvalDramWindow<Gemm0Combined>(
|
||||
null_randval_window_tmp, seqlen_k_start);
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
@@ -522,17 +503,21 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
l(i_idx) = l[i_idx] + rowsum_p[i_idx];
|
||||
});
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
dropout.template Run<Gemm0Combined, CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
auto p = cast_tile<PDataType>(pcomp_tile);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -119,6 +119,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEorLSEaccDramBlockWindow,
|
||||
typename NullRandValDramWindowTmp,
|
||||
typename HstuMask>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
@@ -126,6 +127,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEorLSEaccDramBlockWindow& lse_or_lse_acc_dram_block_window, // M0 tile
|
||||
NullRandValDramWindowTmp& null_randval_window_tmp, // M0*N0 tile
|
||||
index_t seqlen_k_start,
|
||||
index_t seqlen_k_end,
|
||||
HstuMask& mask,
|
||||
@@ -165,6 +167,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
constexpr auto gemm_1 =
|
||||
Policy::template GetPVTBlockGemm<Problem, true /*kPipelineUseTrLoad*/>();
|
||||
|
||||
using Gemm0Combined = decltype(Policy::template GetQKCombinedBlockGemm<Problem>());
|
||||
|
||||
// SaccBlockTile size is [kM0, kN0Sub]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile<kM0, kN0Sub>());
|
||||
@@ -303,31 +307,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<Problem>());
|
||||
|
||||
auto null_randval_window = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
// need to make a tile window from this null_randval_dram since the null_tile_window
|
||||
// does not have store_tile() over-loaded, will cause compiling issue when used
|
||||
// inside BlockDropout::Run()
|
||||
const auto null_randval_dram = [&]() {
|
||||
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<uint8_t*>(nullptr),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(null_dram_naive,
|
||||
make_tuple(number<1>{}, number<1>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
auto null_randval_window = dropout.template MakeRandvalDramWindow<Gemm0Combined>(
|
||||
null_randval_window_tmp, seqlen_k_start);
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
@@ -528,18 +509,22 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
l(i_idx) = l[i_idx] + rowsum_p[i_idx];
|
||||
});
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) +
|
||||
Policy::template GetSmemSizeKV<Problem, true /*kPipelineUseTrLoad*/>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
dropout.template Run<Gemm0Combined, CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
seqlen_k_curr += kN0;
|
||||
|
||||
auto p = cast_tile<PDataType>(pcomp_tile);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
#!/bin/bash
|
||||
|
||||
BUILD=build
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -p_drop=0.2"
|
||||
|
||||
attn_scale=0
|
||||
if [ $# -ge 1 ]; then
|
||||
attn_scale=$1
|
||||
fi
|
||||
|
||||
ndist=0
|
||||
|
||||
if [ $# -ge 2 ]; then
|
||||
ndist=$2
|
||||
fi
|
||||
|
||||
for dtype in "fp16" "bf16"; do
|
||||
set -x
|
||||
|
||||
## no masking batched
|
||||
$EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## no masking jagged
|
||||
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## batched causal
|
||||
$EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal
|
||||
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## batched causal+local
|
||||
$EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal+local
|
||||
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## batched causal+local+context
|
||||
$EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal+local+context
|
||||
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## batched causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=50 -jagged=0 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged no-causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal+local+target (minfull_len > max_uih_len)
|
||||
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal+local+context+target (minfull_len > max_uih_len)
|
||||
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged no-causal+local+context+target (minfull_len > max_uih_len)
|
||||
$EXE -v=1 -prec=$dtype -b=50 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
set +x
|
||||
done
|
||||
Reference in New Issue
Block a user