mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
[CK_TILE] support alibi (#1269)
* add alibi support
* fix code
* update code based on comment
* Support more hdim
* fix fp8 bias
* support seqlen_k=0 case
* remove unused printf
* fix format
---------
Co-authored-by: rocking <ChunYu.Lai@amd.com>
[ROCm/composable_kernel commit: 851c3ed157]
This commit is contained in:
@@ -154,3 +154,8 @@
|
||||
#ifndef CK_TILE_USE_SUBDWORD_TILE_CAST
|
||||
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0
|
||||
#endif
|
||||
|
||||
// TODO: better solve this inside compiler
|
||||
#ifndef CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
#define CK_TILE_FMHA_FWD_FAST_EXP2 0
|
||||
#endif
|
||||
|
||||
@@ -536,4 +536,15 @@ float log(float x) { return __logf(x); };
|
||||
CK_TILE_HOST
|
||||
float log(float x) { return std::logf(x); };
|
||||
|
||||
CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
|
||||
{
|
||||
// TODO: this is hacky, we use u16
|
||||
return __builtin_amdgcn_sad_u16(x, y, acc);
|
||||
}
|
||||
|
||||
CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
|
||||
{
|
||||
return (x > y ? (x - y) : (y - x)) + acc;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
|
||||
|
||||
37
include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp
Normal file
37
include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp
Normal file
@@ -0,0 +1,37 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This class is used for codegen pattern matching
|
||||
enum class BlockAttentionBiasEnum
|
||||
{
|
||||
NO_BIAS = 0,
|
||||
ELEMENTWISE_BIAS = 1, // attention bias, each elements add to the result of Q*K(after scale)
|
||||
ALIBI = 2, // bias computed with position encoding, applied after scale
|
||||
};
|
||||
|
||||
template <BlockAttentionBiasEnum>
|
||||
struct BlockAttentionBiasEnumToStr;
|
||||
|
||||
template <>
|
||||
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::NO_BIAS>
|
||||
{
|
||||
static constexpr const char* name = "";
|
||||
};
|
||||
template <>
|
||||
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ELEMENTWISE_BIAS>
|
||||
{
|
||||
static constexpr const char* name = "bias";
|
||||
};
|
||||
template <>
|
||||
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ALIBI>
|
||||
{
|
||||
static constexpr const char* name = "alibi";
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
189
include/ck_tile/ops/fmha/block/block_position_encoding.hpp
Normal file
189
include/ck_tile/ops/fmha/block/block_position_encoding.hpp
Normal file
@@ -0,0 +1,189 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum struct PositionEncodingEnum
|
||||
{
|
||||
NO = 0,
|
||||
ALIBI = 1,
|
||||
};
|
||||
|
||||
/*
|
||||
VERTICAL:
|
||||
[0] 1 2 3 4 5
|
||||
[0] 1 2 3 4 5
|
||||
[0] 1 2 3 4 5
|
||||
[0] 1 2 3 4 5
|
||||
|
||||
TOP_LEFT:
|
||||
[0] 1 2 3 4 5
|
||||
1 [0] 1 2 3 4
|
||||
2 1 [0] 1 2 3
|
||||
3 2 1 [0] 1 2
|
||||
|
||||
FROM_BOTTOM_RIGHT:
|
||||
2 1 [0] 1 2 3
|
||||
3 2 1 [0] 1 2
|
||||
4 3 2 1 [0] 1
|
||||
5 4 3 2 1 [0]
|
||||
*/
|
||||
|
||||
enum struct AlibiMode
|
||||
{
|
||||
VERTICAL = 0,
|
||||
FROM_TOP_LEFT = 1, // keep sync with mask enum
|
||||
FROM_BOTTOM_RIGHT = 2,
|
||||
};
|
||||
|
||||
template <typename DataType, bool RowMajor = true>
|
||||
struct Alibi
|
||||
{
|
||||
// RowMajor here means if pixel within the same thread are along the row, or col
|
||||
// this may impact the performance of update(), while the result are the same.
|
||||
// e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false
|
||||
CK_TILE_HOST_DEVICE Alibi(DataType slope_,
|
||||
index_t y_total_,
|
||||
index_t x_total_,
|
||||
AlibiMode mode_ = AlibiMode::VERTICAL)
|
||||
{
|
||||
slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope;
|
||||
|
||||
shift_left_up = [&]() {
|
||||
if(RowMajor)
|
||||
{
|
||||
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
|
||||
}
|
||||
}();
|
||||
shift_right_down = [&]() {
|
||||
if(RowMajor)
|
||||
{
|
||||
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
|
||||
}
|
||||
}();
|
||||
mode = mode_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
|
||||
{
|
||||
if constexpr(RowMajor)
|
||||
{
|
||||
// at least 3 instructions per row
|
||||
index_t current_zero_point =
|
||||
mode == AlibiMode::VERTICAL ? shift_right_down : row_idx + shift_right_down;
|
||||
|
||||
// for every threads, most of the pixels are along the row, below operation should be
|
||||
// the main hot spot.
|
||||
auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
|
||||
bit_cast<uint32_t>(col_idx + shift_left_up),
|
||||
0));
|
||||
pixel += slope * position;
|
||||
}
|
||||
else
|
||||
{
|
||||
// at least 3 instructions per col;
|
||||
index_t current_zero_point = mode == AlibiMode::VERTICAL
|
||||
? row_idx + col_idx + shift_right_down
|
||||
: col_idx + shift_right_down;
|
||||
|
||||
// for every threads, most of the pixels are along the col, below operation should be
|
||||
// the main hot spot.
|
||||
auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
|
||||
bit_cast<uint32_t>(row_idx + shift_left_up),
|
||||
0));
|
||||
pixel += slope * position;
|
||||
}
|
||||
}
|
||||
|
||||
DataType slope; // float?
|
||||
index_t shift_left_up; // always possitive
|
||||
index_t shift_right_down; // always possitive
|
||||
AlibiMode mode;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct EmptyPositionEncoding
|
||||
{
|
||||
CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// can convert from the FA style left/right to our generic coordinate
|
||||
// if left_size < 0 && right_size = 0, it is normal causal mask
|
||||
// local is left_size >=0 or right_size >=0
|
||||
template <typename DataType, bool RowMajor = true>
|
||||
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
|
||||
index_t window_left_size,
|
||||
index_t window_right_size,
|
||||
index_t y_total,
|
||||
index_t x_total,
|
||||
GenericAttentionMaskEnum mask_enum)
|
||||
{
|
||||
// assume mask_enum will never be NO_MASK, since if we do not have mask, it's
|
||||
// totally OK to use constexpr
|
||||
bool is_causal = window_left_size < 0 && window_right_size == 0;
|
||||
AlibiMode alibi_mode =
|
||||
is_causal ? AlibiMode::VERTICAL
|
||||
: static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
|
||||
return Alibi<DataType, RowMajor>{slope, y_total, x_total, alibi_mode};
|
||||
}
|
||||
|
||||
// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
|
||||
// Do we need a device version?
|
||||
template <typename DataType>
|
||||
CK_TILE_HOST std::vector<DataType> get_alibi_slopes(ck_tile::index_t nheads)
|
||||
{
|
||||
auto get_slopes_power_of_2 = [](ck_tile::index_t n) {
|
||||
float start = std::powf(
|
||||
static_cast<float>(2),
|
||||
-std::powf(static_cast<float>(2), -static_cast<float>((integer_log2_floor(n) - 3))));
|
||||
|
||||
std::vector<DataType> rtn;
|
||||
for(auto i = 0; i < n; i++)
|
||||
{
|
||||
rtn.push_back(static_cast<DataType>(start * std::powf(start, i)));
|
||||
}
|
||||
return rtn;
|
||||
};
|
||||
if(is_power_of_two_integer(nheads))
|
||||
{
|
||||
// power of 2 calculation
|
||||
return get_slopes_power_of_2(nheads);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads);
|
||||
auto v0 = get_slopes_power_of_2(closest_power_of_2);
|
||||
auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2);
|
||||
auto v1_sliced = [&](auto vec, ck_tile::index_t rem) {
|
||||
std::vector<DataType> sliced;
|
||||
for(ck_tile::index_t i = 0; i < static_cast<ck_tile::index_t>(vec.size()); i++)
|
||||
{
|
||||
if(i % 2 == 0)
|
||||
sliced.push_back(vec[i]);
|
||||
}
|
||||
std::vector<DataType> sliced_2(sliced.begin(), sliced.begin() + rem);
|
||||
return sliced_2;
|
||||
}(v1, nheads - closest_power_of_2);
|
||||
v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end());
|
||||
return v0;
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
@@ -33,6 +34,7 @@ struct FmhaFwdKernel
|
||||
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
|
||||
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
|
||||
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
|
||||
|
||||
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
|
||||
|
||||
@@ -41,7 +43,7 @@ struct FmhaFwdKernel
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = FmhaPipeline::kHasBias;
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
@@ -81,7 +83,8 @@ struct FmhaFwdKernel
|
||||
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
|
||||
(kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
|
||||
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
@@ -136,6 +139,13 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_bias = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdAlibiKargs
|
||||
{
|
||||
// alibi is batch*nhead*1, no matter in batch/group mode, they are the same
|
||||
const void* alibi_slope_ptr;
|
||||
ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
|
||||
};
|
||||
|
||||
struct FmhaFwdMaskKargs
|
||||
{
|
||||
// ck_tile::index_t window_size_left, window_size_right;
|
||||
@@ -162,7 +172,11 @@ struct FmhaFwdKernel
|
||||
|
||||
struct FmhaFwdBatchModeKargs
|
||||
: FmhaFwdCommonKargs,
|
||||
std::conditional_t<kHasBias, FmhaFwdBatchModeBiasKargs, FmhaFwdEmptyKargs<0>>,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
|
||||
FmhaFwdBatchModeBiasKargs,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
|
||||
FmhaFwdAlibiKargs,
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
|
||||
@@ -175,7 +189,11 @@ struct FmhaFwdKernel
|
||||
|
||||
struct FmhaFwdGroupModeKargs
|
||||
: FmhaFwdCommonKargs,
|
||||
std::conditional_t<kHasBias, FmhaFwdCommonBiasKargs, FmhaFwdEmptyKargs<0>>,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
|
||||
FmhaFwdCommonBiasKargs,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
|
||||
FmhaFwdAlibiKargs,
|
||||
FmhaFwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
|
||||
@@ -255,13 +273,18 @@ struct FmhaFwdKernel
|
||||
batch_stride_v,
|
||||
batch_stride_o};
|
||||
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
kargs.bias_ptr = bias_ptr;
|
||||
kargs.stride_bias = stride_bias;
|
||||
kargs.nhead_stride_bias = nhead_stride_bias;
|
||||
kargs.batch_stride_bias = batch_stride_bias;
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
kargs.alibi_slope_ptr = bias_ptr;
|
||||
kargs.alibi_slope_stride = stride_bias;
|
||||
}
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
@@ -345,12 +368,17 @@ struct FmhaFwdKernel
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
kargs.bias_ptr = bias_ptr;
|
||||
kargs.stride_bias = stride_bias;
|
||||
kargs.nhead_stride_bias = nhead_stride_bias;
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
kargs.alibi_slope_ptr = bias_ptr;
|
||||
kargs.alibi_slope_stride = stride_bias;
|
||||
}
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
kargs.window_size_left = window_size_left;
|
||||
@@ -421,14 +449,10 @@ struct FmhaFwdKernel
|
||||
{
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start * kargs.stride_bias + key_start;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_bias = key_start;
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = query_start;
|
||||
@@ -461,7 +485,7 @@ struct FmhaFwdKernel
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
|
||||
}
|
||||
@@ -585,7 +609,7 @@ struct FmhaFwdKernel
|
||||
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
|
||||
constexpr auto bias_dram_window_lengths =
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
const BiasDataType* bias_ptr =
|
||||
reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
|
||||
@@ -654,6 +678,39 @@ struct FmhaFwdKernel
|
||||
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
|
||||
}();
|
||||
|
||||
// WA i_batch capture structure binding before c++20
|
||||
auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
// data loading, shared by entire wg
|
||||
// TODO: how to use s_read?
|
||||
SaccDataType slope =
|
||||
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
|
||||
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
slope *= ck_tile::log2e_v<>;
|
||||
#endif
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
return make_alibi_from_lr_mask<SaccDataType, true>(slope,
|
||||
kargs.window_size_left,
|
||||
kargs.window_size_right,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_k,
|
||||
kargs.mask_type);
|
||||
}
|
||||
else
|
||||
{
|
||||
return Alibi<SaccDataType, true>{
|
||||
slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return EmptyPositionEncoding<SaccDataType>{};
|
||||
}
|
||||
}();
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
@@ -672,6 +729,7 @@ struct FmhaFwdKernel
|
||||
scales{kargs.scale_p}, // p_compute_element_func
|
||||
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
smem_ptr);
|
||||
}
|
||||
@@ -683,6 +741,7 @@ struct FmhaFwdKernel
|
||||
bias_dram_window,
|
||||
lse_dram_window,
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
smem_ptr);
|
||||
}
|
||||
|
||||
@@ -13,4 +13,23 @@ enum class BlockFmhaPipelineEnum
|
||||
QSKSVS,
|
||||
};
|
||||
|
||||
template <BlockFmhaPipelineEnum>
|
||||
struct BlockFmhaPipelineEnumToStr;
|
||||
|
||||
template <>
|
||||
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS>
|
||||
{
|
||||
static constexpr const char* name = "qr";
|
||||
};
|
||||
template <>
|
||||
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC>
|
||||
{
|
||||
static constexpr const char* name = "qr_async";
|
||||
};
|
||||
template <>
|
||||
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QSKSVS>
|
||||
{
|
||||
static constexpr const char* name = "qs";
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -45,7 +45,7 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Traits::kHasBias;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
@@ -46,7 +47,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
@@ -82,7 +83,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
else if constexpr(kK0BlockLength <= 128)
|
||||
{
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
@@ -113,7 +114,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
typename LSEElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction>
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
@@ -129,6 +131,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
@@ -270,13 +273,13 @@ struct BlockFmhaPipelineQRKSVS
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
}
|
||||
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
@@ -322,7 +325,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
@@ -338,6 +341,25 @@ struct BlockFmhaPipelineQRKSVS
|
||||
s_acc,
|
||||
bias_tile);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
s_acc(i_j_idx) *= scale_s;
|
||||
position_encoding.update(s_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
@@ -382,7 +404,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
@@ -403,7 +426,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
@@ -427,7 +451,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
@@ -519,7 +544,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
@@ -563,7 +589,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp>
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
@@ -571,6 +598,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
@@ -588,6 +616,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
smem_ptr);
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
@@ -51,7 +52,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
|
||||
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
@@ -79,21 +80,22 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
if constexpr(kK0BlockLength <= 32)
|
||||
{
|
||||
if constexpr(kPadSeqLenK && kHasBias && FmhaMask::IsMasking)
|
||||
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
|
||||
FmhaMask::IsMasking)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kK0BlockLength <= 64)
|
||||
{
|
||||
if constexpr(kPadSeqLenK && kHasBias)
|
||||
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 2;
|
||||
else
|
||||
return 3;
|
||||
}
|
||||
else if constexpr(kK0BlockLength <= 128)
|
||||
{
|
||||
if constexpr(kPadSeqLenK && kHasBias)
|
||||
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
@@ -124,7 +126,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
typename LSEElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction>
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
@@ -140,6 +143,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
@@ -247,8 +251,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
// check early exit
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
@@ -367,7 +371,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
__builtin_amdgcn_sched_barrier(1);
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
@@ -383,6 +387,25 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
s_acc,
|
||||
bias_tile);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
s_acc(i_j_idx) *= scale_s;
|
||||
position_encoding.update(s_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
@@ -463,8 +486,9 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
/// consideration. alibi does not have this problem
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
@@ -485,7 +509,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
@@ -509,7 +534,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
@@ -617,7 +643,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
@@ -661,7 +688,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp>
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
@@ -669,6 +697,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
@@ -686,6 +715,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
smem_ptr);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
@@ -46,7 +47,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
@@ -82,7 +83,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
}
|
||||
else if constexpr(kK0BlockLength <= 128)
|
||||
{
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
@@ -105,7 +106,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp>
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
@@ -113,6 +115,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported
|
||||
FmhaMask mask,
|
||||
PositionEncoding /*position_encoding*/,
|
||||
float scale_s,
|
||||
float descale_qk,
|
||||
float descale_sv,
|
||||
@@ -249,13 +252,13 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
}
|
||||
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
@@ -300,7 +303,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
}
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
@@ -356,7 +359,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
@@ -377,7 +381,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
@@ -401,7 +405,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -45,7 +46,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
@@ -63,7 +64,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
}
|
||||
else if constexpr(kK0BlockLength <= 128)
|
||||
{
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
@@ -99,7 +100,8 @@ struct BlockFmhaPipelineQSKSVS
|
||||
typename LSEElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction>
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
@@ -115,6 +117,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
@@ -265,13 +268,13 @@ struct BlockFmhaPipelineQSKSVS
|
||||
k_block_tile = load_tile(k_dram_window);
|
||||
}
|
||||
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
@@ -313,7 +316,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
}
|
||||
|
||||
// STAGE 2, scale_s, add bias, mask, softmax
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
@@ -329,6 +332,25 @@ struct BlockFmhaPipelineQSKSVS
|
||||
s_acc,
|
||||
bias_tile);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
s_acc(i_j_idx) *= scale_s;
|
||||
position_encoding.update(s_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
@@ -373,7 +395,8 @@ struct BlockFmhaPipelineQSKSVS
|
||||
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
@@ -394,7 +417,8 @@ struct BlockFmhaPipelineQSKSVS
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
@@ -418,7 +442,8 @@ struct BlockFmhaPipelineQSKSVS
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
|
||||
}
|
||||
@@ -510,7 +535,8 @@ struct BlockFmhaPipelineQSKSVS
|
||||
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
|
||||
}
|
||||
@@ -554,7 +580,8 @@ struct BlockFmhaPipelineQSKSVS
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp>
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
@@ -562,6 +589,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
@@ -579,6 +607,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
smem_ptr);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -11,7 +12,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
bool kHasBias_,
|
||||
BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kStoreLSE_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
@@ -21,7 +22,7 @@ struct TileFmhaTraits
|
||||
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
|
||||
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr bool kHasBias = kHasBias_;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kStoreLSE = kStoreLSE_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
|
||||
Reference in New Issue
Block a user