mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
Remove the comparison of row/col against max_uih_len in masking
This commit is contained in:
@@ -575,7 +575,7 @@ struct HstuAttentionFwdKernel
|
||||
const auto q_dram = [&]() {
|
||||
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.hdim_qk),
|
||||
make_tuple(mask.max_uih_len, kargs.hdim_qk),
|
||||
make_tuple(kargs.seq_stride_q, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentQ>{},
|
||||
number<1>{});
|
||||
@@ -584,20 +584,20 @@ struct HstuAttentionFwdKernel
|
||||
return pad_tensor_view(q_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kSubQKHeaddim>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQK>{});
|
||||
sequence<false, kPadHeadDimQK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(q_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQK>{});
|
||||
sequence<false, kPadHeadDimQK>{});
|
||||
}
|
||||
}();
|
||||
const auto k_dram = [&]() {
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.hdim_qk),
|
||||
make_tuple(mask.max_uih_len, kargs.hdim_qk),
|
||||
make_tuple(kargs.seq_stride_k, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
@@ -612,7 +612,7 @@ struct HstuAttentionFwdKernel
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.hdim_v),
|
||||
make_tuple(mask.max_uih_len, kargs.hdim_v),
|
||||
make_tuple(kargs.seq_stride_v, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
@@ -732,7 +732,7 @@ struct HstuAttentionFwdKernel
|
||||
auto o_dram = [&]() {
|
||||
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.hdim_v),
|
||||
make_tuple(mask.max_uih_len, kargs.hdim_v),
|
||||
make_tuple(kargs.seq_stride_o, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentO>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -388,20 +388,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
return !mask.IsTokenPairInsideMask(row, col);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
if(q_origin.at(number<0>{}) + kM0 > mask.max_uih_len)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
set_tile_if(s_acc, type_convert<GemmAccDataType>(0), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !mask.IsTokenPairInsideMask(row, col);
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
auto s = cast_tile<CompDataType>(s_acc); // S{j}
|
||||
auto s = cast_tile<CompDataType>(s_acc);
|
||||
|
||||
tile_elementwise_inout(f_silu, s);
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ struct HstuBlockMaskWithLocal
|
||||
};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col)
|
||||
CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
if(row >= max_uih_len || col >= max_uih_len)
|
||||
return false;
|
||||
@@ -87,6 +87,25 @@ struct HstuBlockMaskWithLocal
|
||||
|
||||
return result;
|
||||
};
|
||||
|
||||
// Device-side masking does not need to compare row/col against max_uih_len, since
|
||||
// the buffer_load_xxx instructions return zero for out-of-bounds accesses.
// Device overload of the mask test: returns true when the token pair (row, col)
// is inside the mask, false otherwise.
|
||||
CK_TILE_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
// Rows below contextual_seqlen are inside the mask regardless of col.
if(row < contextual_seqlen)
|
||||
return true;
|
||||
|
||||
bool result = false;
|
||||
// Causal: col must not exceed row and must be at most max_attn_len behind it.
if constexpr(kUseCausal)
|
||||
result = (row >= col) && (row - col <= max_attn_len);
|
||||
// Non-causal: col may be on either side of row, at most max_attn_len away.
else
|
||||
result = abs(row - col) <= max_attn_len;
|
||||
|
||||
// When min_full_attn_seqlen is positive, every row starting from
// max_uih_len - min_full_attn_seqlen is also inside the mask, whatever col is.
// NOTE(review): with no upper bound check here, rows >= max_uih_len satisfy this
// condition too; presumably harmless given the out-of-bounds note above -- confirm.
if(min_full_attn_seqlen > 0)
|
||||
result = result || (row >= max_uih_len - min_full_attn_seqlen);
|
||||
|
||||
return result;
|
||||
};
|
||||
};
|
||||
|
||||
template <bool kUseCausal>
|
||||
@@ -123,7 +142,7 @@ struct HstuBlockMaskNoLocal
|
||||
};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col)
|
||||
CK_TILE_HOST constexpr bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
if(row >= max_uih_len || col >= max_uih_len)
|
||||
return false;
|
||||
@@ -140,6 +159,23 @@ struct HstuBlockMaskNoLocal
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
// Device-side masking does not need to compare row/col against max_uih_len, since
|
||||
// the buffer_load_xxx instructions return zero for out-of-bounds accesses.
// Device overload of the mask test: returns true when the token pair (row, col)
// is inside the mask, false otherwise.
|
||||
CK_TILE_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
// Rows below contextual_seqlen are inside the mask regardless of col.
if(row < contextual_seqlen)
|
||||
return true;
|
||||
|
||||
// IsMasking is defined outside this span; when it is set, the pair is inside
// the mask only if col does not exceed row.
if constexpr(IsMasking)
|
||||
{
|
||||
bool result = (row >= col);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Reached only when IsMasking is false: every pair is inside the mask.
return true;
|
||||
};
|
||||
};
|
||||
|
||||
template <bool kUseCausal, bool kUseLocal>
|
||||
@@ -157,10 +193,12 @@ CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(int seqlen_,
|
||||
int max_attn_len_,
|
||||
int min_full_attn_seqlen_)
|
||||
{
|
||||
auto max_uih_len_ = seqlen_;
|
||||
|
||||
max_uih_len_ -= contextual_seqlen_ > 0 ? contextual_seqlen_ - 1 : 0;
|
||||
max_uih_len_ -= num_target;
|
||||
auto max_uih_len_ = [&]() {
|
||||
if(contextual_seqlen_ > 0)
|
||||
return seqlen_ - (contextual_seqlen_ - 1) - num_target;
|
||||
else
|
||||
return seqlen_ - num_target;
|
||||
}();
|
||||
|
||||
return HstuBlockMaskType{
|
||||
contextual_seqlen_, max_uih_len_, max_attn_len_, min_full_attn_seqlen_};
|
||||
@@ -170,10 +208,12 @@ template <typename HstuBlockMaskType>
|
||||
// Builds an HstuBlockMaskType from a sequence length, a contextual length and a
// target count. The mask is constructed from contextual_seqlen_ and a derived
// max_uih_len_ = seqlen_ - num_target, reduced by a further (contextual_seqlen_ - 1)
// when contextual_seqlen_ is positive.
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_hstu_block_mask_without_local(int seqlen_, int contextual_seqlen_, int num_target)
|
||||
{
|
||||
// NOTE(review): max_uih_len_ is defined twice in this listing -- by the three
// statements directly below and again by the lambda after them. This looks like
// the before/after of a diff rendered without +/- markers; both forms yield the
// same value. Confirm which one is live in the actual file.
auto max_uih_len_ = seqlen_;
|
||||
|
||||
max_uih_len_ -= contextual_seqlen_ > 0 ? contextual_seqlen_ - 1 : 0;
|
||||
max_uih_len_ -= num_target;
|
||||
// Same computation expressed as an immediately-invoked lambda.
auto max_uih_len_ = [&]() {
|
||||
if(contextual_seqlen_ > 0)
|
||||
return seqlen_ - (contextual_seqlen_ - 1) - num_target;
|
||||
else
|
||||
return seqlen_ - num_target;
|
||||
}();
|
||||
|
||||
return HstuBlockMaskType{contextual_seqlen_, max_uih_len_};
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user