Split HstuBlockMasking into HstuBlockMaskWithLocal and HstuBlockMaskNoLocal to save VGPRs when local masking is not used

This commit is contained in:
Qianfeng Zhang
2025-04-15 14:40:55 +00:00
parent cad1356170
commit 3cd1b13e46
6 changed files with 163 additions and 110 deletions

View File

@@ -162,8 +162,8 @@ static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdPara
// Returns the pair (rtol, atol) packed with ck_tile::make_tuple.
// DataType is not referenced anywhere in the visible body.
// NOTE(review): rtol/atol are presumably the relative/absolute tolerances used when
// comparing results against a reference — confirm against the caller.
// NOTE(review): both the 1e-2 and the 2e-2 declarations are visible here because this
// view is a diff hunk without +/- markers; presumably the 2e-2 pair is the post-commit
// value — confirm against the actual file.
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-2;
double atol = 1e-2;
double rtol = 2e-2;
double atol = 2e-2;
return ck_tile::make_tuple(rtol, atol);
}

View File

@@ -32,7 +32,7 @@ template <typename InOutDataType,
struct batched_forward_causal_local_bias_dropout_dispatch
{
using HstuAttentionShape = typename HstuAttentionFwdShape<MaxK>::Type;
using HstuMask = ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>;
using HstuMask = typename ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>::Type;
template <typename HstuTraits>
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<

View File

@@ -11,6 +11,8 @@
#include <utility>
#include <variant>
#include "hstu_block_masking.hpp"
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
@@ -44,7 +46,7 @@ struct HstuAttentionFwdKernel
static constexpr auto kHasBias = HstuAttentionPipeline::kHasBias;
static constexpr bool kHasDropout = HstuAttentionPipeline::kHasDropout;
using HstuMask = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::HstuMask>;
static constexpr bool kHasMask = HstuMask::IsMasking;
static constexpr bool kHasLocalMask = HstuMask::kUseLocal;
template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template
// arg
@@ -124,15 +126,16 @@ struct HstuAttentionFwdKernel
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
};
struct HstuAttentionFwdBatchModeKargs
: HstuAttentionFwdCommonKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdBatchModeBiasKargs,
HstuAttentionFwdEmptyKargs<0>>,
std::conditional_t<kHasMask, HstuAttentionFwdMaskKargs, HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
HstuAttentionFwdEmptyKargs<2>>
struct HstuAttentionFwdBatchModeKargs : HstuAttentionFwdCommonKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdBatchModeBiasKargs,
HstuAttentionFwdEmptyKargs<0>>,
std::conditional_t<kHasLocalMask,
HstuAttentionFwdMaskKargs,
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
HstuAttentionFwdEmptyKargs<2>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
@@ -140,15 +143,16 @@ struct HstuAttentionFwdKernel
ck_tile::index_t batch_stride_o;
};
struct HstuAttentionFwdJaggModeKargs
: HstuAttentionFwdCommonKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdCommonBiasKargs,
HstuAttentionFwdEmptyKargs<0>>,
std::conditional_t<kHasMask, HstuAttentionFwdMaskKargs, HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
HstuAttentionFwdEmptyKargs<2>>
struct HstuAttentionFwdJaggModeKargs : HstuAttentionFwdCommonKargs,
std::conditional_t<kHasBias,
HstuAttentionFwdCommonBiasKargs,
HstuAttentionFwdEmptyKargs<0>>,
std::conditional_t<kHasLocalMask,
HstuAttentionFwdMaskKargs,
HstuAttentionFwdEmptyKargs<1>>,
std::conditional_t<kHasDropout,
HstuAttentionFwdCommonDropoutKargs,
HstuAttentionFwdEmptyKargs<2>>
{
const int32_t* seq_offsets_ptr;
};
@@ -224,7 +228,7 @@ struct HstuAttentionFwdKernel
kargs.nhead_stride_bias = nhead_stride_bias;
kargs.batch_stride_bias = batch_stride_bias;
}
if constexpr(kHasMask)
if constexpr(kHasLocalMask)
{
kargs.window_size = window_size;
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
@@ -366,7 +370,7 @@ struct HstuAttentionFwdKernel
kargs.seq_stride_bias = seq_stride_bias;
kargs.nhead_stride_bias = nhead_stride_bias;
}
if constexpr(kHasMask)
if constexpr(kHasLocalMask)
{
kargs.window_size = window_size;
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
@@ -542,14 +546,15 @@ struct HstuAttentionFwdKernel
int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch];
HstuMask mask = [&]() {
if constexpr(kHasMask)
return HstuMask{kargs.seqlen,
kargs.contextual_seqlen,
num_target,
kargs.window_size,
kargs.min_full_attn_seqlen};
if constexpr(kHasLocalMask)
return make_hstu_block_mask_with_local<HstuMask>(kargs.seqlen,
kargs.contextual_seqlen,
num_target,
kargs.window_size,
kargs.min_full_attn_seqlen);
else
return HstuMask{kargs.seqlen, kargs.contextual_seqlen, num_target};
return make_hstu_block_mask_without_local<HstuMask>(
kargs.seqlen, kargs.contextual_seqlen, num_target);
}();
// for simplicity, batch stride we just modify the pointer

View File

@@ -32,7 +32,7 @@ template <typename InOutDataType,
struct jagged_forward_causal_local_bias_dropout_dispatch
{
using HstuAttentionShape = typename HstuAttentionFwdShape<MaxK>::Type;
using HstuMask = ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>;
using HstuMask = typename ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>::Type;
template <typename HstuTraits>
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<

View File

@@ -3,14 +3,16 @@
#pragma once
#include <type_traits>
#include "ck_tile/core.hpp"
namespace ck_tile {
template <bool kUseCausal, bool kUseLocal>
struct HstuBlockMasking
template <bool kUseCausal>
struct HstuBlockMaskWithLocal
{
static constexpr bool IsMasking = (kUseCausal || kUseLocal);
static constexpr bool kUseLocal = true;
static constexpr bool IsMasking = true;
int contextual_seqlen;
int max_uih_len;
@@ -18,30 +20,86 @@ struct HstuBlockMasking
int max_attn_len;
int min_full_attn_seqlen;
CK_TILE_HOST_DEVICE HstuBlockMasking(int seqlen_,
int contextual_seqlen_,
int num_target,
int max_attn_len_,
int min_full_attn_seqlen_)
CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(int contextual_seqlen_,
int max_uih_len_,
int max_attn_len_,
int min_full_attn_seqlen_)
: contextual_seqlen(contextual_seqlen_),
max_uih_len(max_uih_len_),
max_attn_len(max_attn_len_),
min_full_attn_seqlen(min_full_attn_seqlen_){};
// Returns the loop range along the X axis as an index pair [start, end), where end - start = length.
// Use this when looping over the X axis tile by tile (e.g. the seqlen_k loop).
// i_y is the start offset of the current tile along the seqlen_q dimension.
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
max_uih_len = seqlen_;
contextual_seqlen = contextual_seqlen_;
if(i_y < contextual_seqlen)
return ck_tile::make_tuple(0, max_uih_len);
max_attn_len = max_attn_len_;
min_full_attn_seqlen = min_full_attn_seqlen_;
if constexpr(!kUseCausal)
{
if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen)
{
return ck_tile::make_tuple(0, max_uih_len);
}
else
{
index_t x_start = max(0, i_y - max_attn_len);
index_t x_end = i_y + YTile + max_attn_len;
max_uih_len -= contextual_seqlen > 0 ? contextual_seqlen - 1 : 0;
max_uih_len -= num_target;
};
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
};
}
else // kUseCausal && kUseLocal
{
if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen)
{
return ck_tile::make_tuple(0, max_uih_len);
}
else
{
index_t x_end = i_y + YTile + max_attn_len;
CK_TILE_HOST_DEVICE HstuBlockMasking(int seqlen_, int contextual_seqlen_, int num_target)
return ck_tile::make_tuple(0, x_end);
};
};
}
CK_TILE_HOST_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col)
{
max_uih_len = seqlen_;
contextual_seqlen = contextual_seqlen_;
if(row >= max_uih_len || col >= max_uih_len)
return false;
max_uih_len -= contextual_seqlen > 0 ? contextual_seqlen - 1 : 0;
max_uih_len -= num_target;
if(row < contextual_seqlen)
return true;
bool result = false;
if constexpr(kUseCausal)
result = (row >= col) && (row - col <= max_attn_len);
else
result = abs(row - col) <= max_attn_len;
if(min_full_attn_seqlen > 0)
result = result || (row >= max_uih_len - min_full_attn_seqlen);
return result;
};
};
template <bool kUseCausal>
struct HstuBlockMaskNoLocal
{
static constexpr bool kUseLocal = false;
static constexpr bool IsMasking = kUseCausal;
int contextual_seqlen;
int max_uih_len;
CK_TILE_HOST_DEVICE HstuBlockMaskNoLocal(int contextual_seqlen_, int max_uih_len_)
: contextual_seqlen(contextual_seqlen_), max_uih_len(max_uih_len_){};
// Returns the loop range along the X axis as an index pair [start, end), where end - start = length.
// Use this when looping over the X axis tile by tile (e.g. the seqlen_k loop).
@@ -59,40 +117,9 @@ struct HstuBlockMasking
if(i_y < contextual_seqlen)
return ck_tile::make_tuple(0, max_uih_len);
if constexpr(kUseCausal && !kUseLocal)
{
index_t x_end =
min(i_y + YTile, max_uih_len); // for lower-triangular masking, x <= y
index_t x_end = min(i_y + YTile, max_uih_len); // for lower-triangular masking, x <= y
return ck_tile::make_tuple(0, x_end);
}
else if constexpr(!kUseCausal && kUseLocal)
{
if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen)
{
return ck_tile::make_tuple(0, max_uih_len);
}
else
{
index_t x_start = max(0, i_y - max_attn_len);
index_t x_end = i_y + YTile + max_attn_len;
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
};
}
else // kUseCausal && kUseLocal
{
if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen)
{
return ck_tile::make_tuple(0, max_uih_len);
}
else
{
index_t x_end = i_y + YTile + max_attn_len;
return ck_tile::make_tuple(0, x_end);
};
};
return ck_tile::make_tuple(0, x_end);
};
}
@@ -106,21 +133,7 @@ struct HstuBlockMasking
if constexpr(IsMasking)
{
bool result = false;
if constexpr(kUseLocal)
{
if constexpr(kUseCausal)
result = (row >= col) && (row - col <= max_attn_len);
else
result = abs(row - col) <= max_attn_len;
if(min_full_attn_seqlen > 0)
result = (row >= max_uih_len - min_full_attn_seqlen);
}
else
{
result = (row >= col);
};
bool result = (row >= col);
return result;
}
@@ -129,4 +142,40 @@ struct HstuBlockMasking
};
};
template <bool kUseCausal, bool kUseLocal>
struct HstuBlockMasking
{
using Type = std::conditional_t<kUseLocal,
HstuBlockMaskWithLocal<kUseCausal>,
HstuBlockMaskNoLocal<kUseCausal>>;
};
// Builds a mask object of type HstuBlockMaskType for the local-masking case.
//
// The second constructor argument (max_uih_len) is derived from the inputs:
//   seqlen_ - (contextual_seqlen_ > 0 ? contextual_seqlen_ - 1 : 0) - num_target
// The mask is then constructed from
//   {contextual_seqlen_, max_uih_len, max_attn_len_, min_full_attn_seqlen_}.
template <typename HstuBlockMaskType>
CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(int seqlen_,
                                                                   int contextual_seqlen_,
                                                                   int num_target,
                                                                   int max_attn_len_,
                                                                   int min_full_attn_seqlen_)
{
    // All contextual positions except one are removed from the length.
    const int contextual_excess = (contextual_seqlen_ > 0) ? (contextual_seqlen_ - 1) : 0;
    const int uih_len           = seqlen_ - contextual_excess - num_target;

    return HstuBlockMaskType{contextual_seqlen_, uih_len, max_attn_len_, min_full_attn_seqlen_};
}
// Builds a mask object of type HstuBlockMaskType for the case without local masking.
//
// The second constructor argument (max_uih_len) is derived from the inputs:
//   seqlen_ - (contextual_seqlen_ > 0 ? contextual_seqlen_ - 1 : 0) - num_target
// The mask is then constructed from {contextual_seqlen_, max_uih_len}.
template <typename HstuBlockMaskType>
CK_TILE_HOST_DEVICE constexpr auto
make_hstu_block_mask_without_local(int seqlen_, int contextual_seqlen_, int num_target)
{
    // All contextual positions except one are removed from the length.
    const int contextual_excess = (contextual_seqlen_ > 0) ? (contextual_seqlen_ - 1) : 0;
    const int uih_len           = seqlen_ - contextual_excess - num_target;

    return HstuBlockMaskType{contextual_seqlen_, uih_len};
}
} // namespace ck_tile

View File

@@ -33,8 +33,8 @@ template <typename InOutDataType,
bool kUseLocal>
struct reference_hstu_attention
{
using HstuMask = HstuBlockMasking<kUseCausal, kUseLocal>;
static constexpr bool kHasMask = kUseCausal || kUseLocal;
using HstuMask = typename HstuBlockMasking<kUseCausal, kUseLocal>::Type;
static constexpr bool kHasLocalMask = HstuMask::kUseLocal;
static void Run(const HostTensor<InOutDataType>& q_batch_seq_nhead_hdim,
const HostTensor<InOutDataType>& k_batch_seq_nhead_hdim,
@@ -90,11 +90,9 @@ struct reference_hstu_attention
assert(num_tagets.empty() || num_targets.size() == num_batch);
auto silu = [](CompDataType x) {
auto one = ck_tile::type_convert<CompDataType>(1.0f);
const auto one = ck_tile::type_convert<CompDataType>(1.0f);
auto sigmod_val = one / (one + std::exp(-x));
return sigmod_val * x;
return x / (one + std::exp(-x));
};
auto f = [&](auto i_batch, auto i_head) {
@@ -104,11 +102,12 @@ struct reference_hstu_attention
int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
HstuMask mask = [&]() {
if constexpr(kHasMask)
return HstuMask{
seqlen, contextual_seqlen, num_target, max_attn_len, min_full_attn_seqlen};
if constexpr(kHasLocalMask)
return ck_tile::make_hstu_block_mask_with_local<HstuMask>(
seqlen, contextual_seqlen, num_target, max_attn_len, min_full_attn_seqlen);
else
return HstuMask{seqlen, contextual_seqlen, num_target};
return ck_tile::make_hstu_block_mask_without_local<HstuMask>(
seqlen, contextual_seqlen, num_target);
}();
// for all rows in the batch