mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Split HstuBlockMasking into HstuBlockMaskWithLocal and HstuBlockMaskNoLocal to save vgprs for non-local situations
This commit is contained in:
@@ -162,8 +162,8 @@ static void show_hstu_attention_fwd_param(std::ostream& os, HstuAttentionFwdPara
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
double rtol = 2e-2;
|
||||
double atol = 2e-2;
|
||||
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ template <typename InOutDataType,
|
||||
struct batched_forward_causal_local_bias_dropout_dispatch
|
||||
{
|
||||
using HstuAttentionShape = typename HstuAttentionFwdShape<MaxK>::Type;
|
||||
using HstuMask = ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>;
|
||||
using HstuMask = typename ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>::Type;
|
||||
|
||||
template <typename HstuTraits>
|
||||
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
#include "hstu_block_masking.hpp"
|
||||
|
||||
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
|
||||
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
|
||||
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
|
||||
@@ -44,7 +46,7 @@ struct HstuAttentionFwdKernel
|
||||
static constexpr auto kHasBias = HstuAttentionPipeline::kHasBias;
|
||||
static constexpr bool kHasDropout = HstuAttentionPipeline::kHasDropout;
|
||||
using HstuMask = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::HstuMask>;
|
||||
static constexpr bool kHasMask = HstuMask::IsMasking;
|
||||
static constexpr bool kHasLocalMask = HstuMask::kUseLocal;
|
||||
|
||||
template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce a template
|
||||
// arg
|
||||
@@ -124,15 +126,16 @@ struct HstuAttentionFwdKernel
|
||||
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdBatchModeKargs
|
||||
: HstuAttentionFwdCommonKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdBatchModeBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasMask, HstuAttentionFwdMaskKargs, HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
struct HstuAttentionFwdBatchModeKargs : HstuAttentionFwdCommonKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdBatchModeBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasLocalMask,
|
||||
HstuAttentionFwdMaskKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
@@ -140,15 +143,16 @@ struct HstuAttentionFwdKernel
|
||||
ck_tile::index_t batch_stride_o;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdJaggModeKargs
|
||||
: HstuAttentionFwdCommonKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdCommonBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasMask, HstuAttentionFwdMaskKargs, HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
struct HstuAttentionFwdJaggModeKargs : HstuAttentionFwdCommonKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdCommonBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasLocalMask,
|
||||
HstuAttentionFwdMaskKargs,
|
||||
HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
{
|
||||
const int32_t* seq_offsets_ptr;
|
||||
};
|
||||
@@ -224,7 +228,7 @@ struct HstuAttentionFwdKernel
|
||||
kargs.nhead_stride_bias = nhead_stride_bias;
|
||||
kargs.batch_stride_bias = batch_stride_bias;
|
||||
}
|
||||
if constexpr(kHasMask)
|
||||
if constexpr(kHasLocalMask)
|
||||
{
|
||||
kargs.window_size = window_size;
|
||||
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
|
||||
@@ -366,7 +370,7 @@ struct HstuAttentionFwdKernel
|
||||
kargs.seq_stride_bias = seq_stride_bias;
|
||||
kargs.nhead_stride_bias = nhead_stride_bias;
|
||||
}
|
||||
if constexpr(kHasMask)
|
||||
if constexpr(kHasLocalMask)
|
||||
{
|
||||
kargs.window_size = window_size;
|
||||
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
|
||||
@@ -542,14 +546,15 @@ struct HstuAttentionFwdKernel
|
||||
int num_target = (kargs.num_targets_ptr == nullptr) ? 0 : kargs.num_targets_ptr[i_batch];
|
||||
|
||||
HstuMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
return HstuMask{kargs.seqlen,
|
||||
kargs.contextual_seqlen,
|
||||
num_target,
|
||||
kargs.window_size,
|
||||
kargs.min_full_attn_seqlen};
|
||||
if constexpr(kHasLocalMask)
|
||||
return make_hstu_block_mask_with_local<HstuMask>(kargs.seqlen,
|
||||
kargs.contextual_seqlen,
|
||||
num_target,
|
||||
kargs.window_size,
|
||||
kargs.min_full_attn_seqlen);
|
||||
else
|
||||
return HstuMask{kargs.seqlen, kargs.contextual_seqlen, num_target};
|
||||
return make_hstu_block_mask_without_local<HstuMask>(
|
||||
kargs.seqlen, kargs.contextual_seqlen, num_target);
|
||||
}();
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
|
||||
@@ -32,7 +32,7 @@ template <typename InOutDataType,
|
||||
struct jagged_forward_causal_local_bias_dropout_dispatch
|
||||
{
|
||||
using HstuAttentionShape = typename HstuAttentionFwdShape<MaxK>::Type;
|
||||
using HstuMask = ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>;
|
||||
using HstuMask = typename ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>::Type;
|
||||
|
||||
template <typename HstuTraits>
|
||||
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
|
||||
|
||||
@@ -3,14 +3,16 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <bool kUseCausal, bool kUseLocal>
|
||||
struct HstuBlockMasking
|
||||
template <bool kUseCausal>
|
||||
struct HstuBlockMaskWithLocal
|
||||
{
|
||||
static constexpr bool IsMasking = (kUseCausal || kUseLocal);
|
||||
static constexpr bool kUseLocal = true;
|
||||
static constexpr bool IsMasking = true;
|
||||
|
||||
int contextual_seqlen;
|
||||
int max_uih_len;
|
||||
@@ -18,30 +20,86 @@ struct HstuBlockMasking
|
||||
int max_attn_len;
|
||||
int min_full_attn_seqlen;
|
||||
|
||||
CK_TILE_HOST_DEVICE HstuBlockMasking(int seqlen_,
|
||||
int contextual_seqlen_,
|
||||
int num_target,
|
||||
int max_attn_len_,
|
||||
int min_full_attn_seqlen_)
|
||||
CK_TILE_HOST_DEVICE HstuBlockMaskWithLocal(int contextual_seqlen_,
|
||||
int max_uih_len_,
|
||||
int max_attn_len_,
|
||||
int min_full_attn_seqlen_)
|
||||
: contextual_seqlen(contextual_seqlen_),
|
||||
max_uih_len(max_uih_len_),
|
||||
max_attn_len(max_attn_len_),
|
||||
min_full_attn_seqlen(min_full_attn_seqlen_){};
|
||||
|
||||
// get the loop range along the X axis; returns indices [start, end), where end - start = length
|
||||
// use this when looping over the X axis tile by tile (e.g. the seqlen_k loop)
|
||||
// i_y is the start offset of the current tile along the seqlen_q dimension
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
max_uih_len = seqlen_;
|
||||
contextual_seqlen = contextual_seqlen_;
|
||||
if(i_y < contextual_seqlen)
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
|
||||
max_attn_len = max_attn_len_;
|
||||
min_full_attn_seqlen = min_full_attn_seqlen_;
|
||||
if constexpr(!kUseCausal)
|
||||
{
|
||||
if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen)
|
||||
{
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t x_start = max(0, i_y - max_attn_len);
|
||||
index_t x_end = i_y + YTile + max_attn_len;
|
||||
|
||||
max_uih_len -= contextual_seqlen > 0 ? contextual_seqlen - 1 : 0;
|
||||
max_uih_len -= num_target;
|
||||
};
|
||||
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
|
||||
};
|
||||
}
|
||||
else // kUseCausal && kUseLocal
|
||||
{
|
||||
if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen)
|
||||
{
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t x_end = i_y + YTile + max_attn_len;
|
||||
|
||||
CK_TILE_HOST_DEVICE HstuBlockMasking(int seqlen_, int contextual_seqlen_, int num_target)
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
max_uih_len = seqlen_;
|
||||
contextual_seqlen = contextual_seqlen_;
|
||||
if(row >= max_uih_len || col >= max_uih_len)
|
||||
return false;
|
||||
|
||||
max_uih_len -= contextual_seqlen > 0 ? contextual_seqlen - 1 : 0;
|
||||
max_uih_len -= num_target;
|
||||
if(row < contextual_seqlen)
|
||||
return true;
|
||||
|
||||
bool result = false;
|
||||
if constexpr(kUseCausal)
|
||||
result = (row >= col) && (row - col <= max_attn_len);
|
||||
else
|
||||
result = abs(row - col) <= max_attn_len;
|
||||
|
||||
if(min_full_attn_seqlen > 0)
|
||||
result = result || (row >= max_uih_len - min_full_attn_seqlen);
|
||||
|
||||
return result;
|
||||
};
|
||||
};
|
||||
|
||||
template <bool kUseCausal>
|
||||
struct HstuBlockMaskNoLocal
|
||||
{
|
||||
static constexpr bool kUseLocal = false;
|
||||
static constexpr bool IsMasking = kUseCausal;
|
||||
|
||||
int contextual_seqlen;
|
||||
int max_uih_len;
|
||||
|
||||
CK_TILE_HOST_DEVICE HstuBlockMaskNoLocal(int contextual_seqlen_, int max_uih_len_)
|
||||
: contextual_seqlen(contextual_seqlen_), max_uih_len(max_uih_len_){};
|
||||
|
||||
// get the loop range along the X axis; returns indices [start, end), where end - start = length
|
||||
// use this when looping over the X axis tile by tile (e.g. the seqlen_k loop)
|
||||
@@ -59,40 +117,9 @@ struct HstuBlockMasking
|
||||
if(i_y < contextual_seqlen)
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
|
||||
if constexpr(kUseCausal && !kUseLocal)
|
||||
{
|
||||
index_t x_end =
|
||||
min(i_y + YTile, max_uih_len); // for lower-triangular masking, x <= y
|
||||
index_t x_end = min(i_y + YTile, max_uih_len); // for lower-triangular masking, x <= y
|
||||
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else if constexpr(!kUseCausal && kUseLocal)
|
||||
{
|
||||
if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen)
|
||||
{
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t x_start = max(0, i_y - max_attn_len);
|
||||
index_t x_end = i_y + YTile + max_attn_len;
|
||||
|
||||
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
|
||||
};
|
||||
}
|
||||
else // kUseCausal && kUseLocal
|
||||
{
|
||||
if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen)
|
||||
{
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t x_end = i_y + YTile + max_attn_len;
|
||||
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
};
|
||||
};
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -106,21 +133,7 @@ struct HstuBlockMasking
|
||||
|
||||
if constexpr(IsMasking)
|
||||
{
|
||||
bool result = false;
|
||||
if constexpr(kUseLocal)
|
||||
{
|
||||
if constexpr(kUseCausal)
|
||||
result = (row >= col) && (row - col <= max_attn_len);
|
||||
else
|
||||
result = abs(row - col) <= max_attn_len;
|
||||
|
||||
if(min_full_attn_seqlen > 0)
|
||||
result = (row >= max_uih_len - min_full_attn_seqlen);
|
||||
}
|
||||
else
|
||||
{
|
||||
result = (row >= col);
|
||||
};
|
||||
bool result = (row >= col);
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -129,4 +142,40 @@ struct HstuBlockMasking
|
||||
};
|
||||
};
|
||||
|
||||
template <bool kUseCausal, bool kUseLocal>
|
||||
struct HstuBlockMasking
|
||||
{
|
||||
using Type = std::conditional_t<kUseLocal,
|
||||
HstuBlockMaskWithLocal<kUseCausal>,
|
||||
HstuBlockMaskNoLocal<kUseCausal>>;
|
||||
};
|
||||
|
||||
template <typename HstuBlockMaskType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_hstu_block_mask_with_local(int seqlen_,
|
||||
int contextual_seqlen_,
|
||||
int num_target,
|
||||
int max_attn_len_,
|
||||
int min_full_attn_seqlen_)
|
||||
{
|
||||
auto max_uih_len_ = seqlen_;
|
||||
|
||||
max_uih_len_ -= contextual_seqlen_ > 0 ? contextual_seqlen_ - 1 : 0;
|
||||
max_uih_len_ -= num_target;
|
||||
|
||||
return HstuBlockMaskType{
|
||||
contextual_seqlen_, max_uih_len_, max_attn_len_, min_full_attn_seqlen_};
|
||||
};
|
||||
|
||||
template <typename HstuBlockMaskType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_hstu_block_mask_without_local(int seqlen_, int contextual_seqlen_, int num_target)
|
||||
{
|
||||
auto max_uih_len_ = seqlen_;
|
||||
|
||||
max_uih_len_ -= contextual_seqlen_ > 0 ? contextual_seqlen_ - 1 : 0;
|
||||
max_uih_len_ -= num_target;
|
||||
|
||||
return HstuBlockMaskType{contextual_seqlen_, max_uih_len_};
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -33,8 +33,8 @@ template <typename InOutDataType,
|
||||
bool kUseLocal>
|
||||
struct reference_hstu_attention
|
||||
{
|
||||
using HstuMask = HstuBlockMasking<kUseCausal, kUseLocal>;
|
||||
static constexpr bool kHasMask = kUseCausal || kUseLocal;
|
||||
using HstuMask = typename HstuBlockMasking<kUseCausal, kUseLocal>::Type;
|
||||
static constexpr bool kHasLocalMask = HstuMask::kUseLocal;
|
||||
|
||||
static void Run(const HostTensor<InOutDataType>& q_batch_seq_nhead_hdim,
|
||||
const HostTensor<InOutDataType>& k_batch_seq_nhead_hdim,
|
||||
@@ -90,11 +90,9 @@ struct reference_hstu_attention
|
||||
assert(num_tagets.empty() || num_targets.size() == num_batch);
|
||||
|
||||
auto silu = [](CompDataType x) {
|
||||
auto one = ck_tile::type_convert<CompDataType>(1.0f);
|
||||
const auto one = ck_tile::type_convert<CompDataType>(1.0f);
|
||||
|
||||
auto sigmod_val = one / (one + std::exp(-x));
|
||||
|
||||
return sigmod_val * x;
|
||||
return x / (one + std::exp(-x));
|
||||
};
|
||||
|
||||
auto f = [&](auto i_batch, auto i_head) {
|
||||
@@ -104,11 +102,12 @@ struct reference_hstu_attention
|
||||
int num_target = num_targets.empty() ? 0 : num_targets[i_batch];
|
||||
|
||||
HstuMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
return HstuMask{
|
||||
seqlen, contextual_seqlen, num_target, max_attn_len, min_full_attn_seqlen};
|
||||
if constexpr(kHasLocalMask)
|
||||
return ck_tile::make_hstu_block_mask_with_local<HstuMask>(
|
||||
seqlen, contextual_seqlen, num_target, max_attn_len, min_full_attn_seqlen);
|
||||
else
|
||||
return HstuMask{seqlen, contextual_seqlen, num_target};
|
||||
return ck_tile::make_hstu_block_mask_without_local<HstuMask>(
|
||||
seqlen, contextual_seqlen, num_target);
|
||||
}();
|
||||
|
||||
// for all rows in the batch
|
||||
|
||||
Reference in New Issue
Block a user