mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
1).support receipe in generate.py 2).use simplified mask type 3).change left/right to pass into karg
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
|
||||
|
||||
@@ -7,6 +7,20 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum struct GenericAttentionMaskEnum
|
||||
{
|
||||
NO_MASK = 0,
|
||||
|
||||
// below enum could be causal, or sliding window
|
||||
MASK_FROM_TOP_LEFT = 1,
|
||||
MASK_FROM_BOTTOM_RIGHT = 2,
|
||||
|
||||
// this enum maybe not used by xformer/FA, since it's hard to
|
||||
// specify left/right window for varlen case. put it here for
|
||||
// debug purpose
|
||||
MASK_GENERIC,
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
/* generic Attention Mask Coordinate
|
||||
use x(horizontal axis), y(vertical axis) to describe mask.
|
||||
@@ -188,6 +202,129 @@ struct GenericAttentionMask
|
||||
index_t y_total, x_total;
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
namespace impl {
|
||||
template <bool IsMasking_> struct SimplifiedMaskName;
|
||||
template<> struct SimplifiedMaskName<false> { static constexpr const char * name = "nomask"; };
|
||||
template<> struct SimplifiedMaskName<true> { static constexpr const char * name = "mask"; };
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
// this version only have 2 variation: masking and non-masking
|
||||
// This is more friendly to codegen (e.g. need generate less kernel)
|
||||
// ... with the trade-off that may have more instruction in causal mode
|
||||
template <bool IsMasking_ = true>
|
||||
struct SimplifiedGenericAttentionMask
|
||||
{
|
||||
static constexpr bool IsMasking = IsMasking_; // false will disable masking
|
||||
|
||||
static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
|
||||
|
||||
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
|
||||
: SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
|
||||
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
|
||||
{
|
||||
}
|
||||
template <typename MaskCoordinates>
|
||||
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
|
||||
: y(mask_coord.at(number<0>{})),
|
||||
x(mask_coord.at(number<1>{})),
|
||||
y_total(mask_coord.at(number<2>{})),
|
||||
x_total(mask_coord.at(number<3>{}))
|
||||
{
|
||||
}
|
||||
|
||||
// to get the loop length along X axis, return index:[start, end), end-start=length
|
||||
// use this if need loop over X axis tile by tile (like k-seqlen loopover)
|
||||
// TODO: x_end still could be negative, so end-start could be negative(need check)
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, x_total);
|
||||
}
|
||||
else
|
||||
{
|
||||
// get the tile start/end range assum we loop over along X tile by tile
|
||||
index_t x_start = [&]() {
|
||||
index_t tmp = max(-y + i_y + 1, 0);
|
||||
return (tmp / XTile) * XTile; // round to tile aligned
|
||||
}();
|
||||
|
||||
// TODO: end could be negative, we ignore clamp here, and let caller to check
|
||||
// ... in which case end-start is negative
|
||||
index_t x_end = [&]() {
|
||||
index_t tmp = min(i_y + YTile - 1 + x, x_total);
|
||||
return ((tmp + XTile - 1) / XTile) * XTile;
|
||||
}();
|
||||
|
||||
return ck_tile::make_tuple(x_start, x_end);
|
||||
}
|
||||
}
|
||||
|
||||
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
|
||||
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
// the only case that need do following compare is under kPadSeqLenK
|
||||
// ... for non-masking kernel.
|
||||
return i_x >= x_total;
|
||||
}
|
||||
else
|
||||
{
|
||||
// no need to do min/max here, since i_x will never be < 0 or >= x_total
|
||||
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
|
||||
index_t x_end = i_y + x; // this could be larger than x_total, but it's fine
|
||||
|
||||
return i_x < x_start || i_x >= x_end;
|
||||
}
|
||||
}
|
||||
|
||||
// if current tile is at the edge, means need per-pixel mask check.
|
||||
// otherwise no need to check per-pixel
|
||||
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX()
|
||||
// can be used as a fast-path to decide if do per-pixel check or not
|
||||
template <index_t TileHeight, index_t TileWidth>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
IsEdgeTile(index_t i_y, index_t i_x, number<TileHeight>, number<TileWidth>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
// the only case that need do following compare is under kPadSeqLenK
|
||||
// ... for non-masking kernel.
|
||||
// return (i_x < x_total) && ((i_x + TileWidth) > x_total);
|
||||
|
||||
// TODO: no need to check begin
|
||||
return (i_x + TileWidth) > x_total;
|
||||
}
|
||||
else
|
||||
{
|
||||
// check top-right corner > x or left-borrom corner < x
|
||||
index_t i_x_end = i_x + TileWidth;
|
||||
index_t i_y_end = i_y + TileHeight;
|
||||
// index_t x_end = min(i_y + x, x_total);
|
||||
|
||||
bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad
|
||||
bool bottom_left_edge = i_y_end > (i_x + y);
|
||||
// bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now
|
||||
|
||||
return top_right_edge || bottom_left_edge;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
index_t y, x;
|
||||
index_t y_total, x_total;
|
||||
};
|
||||
|
||||
// TODO: prefer use this function in host code
|
||||
// can convert from the FA style left/right to our generic coordinate
|
||||
// if left_size < 0 && right_size = 0, it is normal causal mask
|
||||
@@ -199,29 +336,32 @@ make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
|
||||
index_t x_total,
|
||||
bool is_top_left = true)
|
||||
{
|
||||
index_t x = 0, y = 0;
|
||||
// TODO: below should all use sgpr arithmetic
|
||||
index_t left_size_tmp = is_top_left ? y_total - 1 : x_total - 1;
|
||||
index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1;
|
||||
|
||||
if(is_top_left)
|
||||
{
|
||||
if(left_size < 0)
|
||||
left_size = y_total - 1;
|
||||
if(right_size < 0)
|
||||
right_size = x_total - 1;
|
||||
left_size = left_size < 0 ? left_size_tmp : left_size;
|
||||
right_size = right_size < 0 ? right_size_tmp : right_size;
|
||||
|
||||
x = 1 + right_size;
|
||||
y = left_size + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(left_size < 0)
|
||||
left_size = x_total - 1;
|
||||
if(right_size < 0)
|
||||
right_size = y_total - 1;
|
||||
index_t x_tmp = is_top_left ? 0 : x_total - y_total;
|
||||
index_t y_tmp = is_top_left ? 0 : y_total - x_total;
|
||||
|
||||
x = x_total - y_total + 1 + right_size;
|
||||
y = y_total - x_total + 1 + left_size;
|
||||
}
|
||||
index_t x = 1 + right_size + x_tmp;
|
||||
index_t y = 1 + left_size + y_tmp;
|
||||
|
||||
return ck_tile::make_tuple(y, x, y_total, x_total);
|
||||
}
|
||||
|
||||
template <typename MaskType>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
make_generic_attention_mask_from_lr_window(index_t left_size,
|
||||
index_t right_size,
|
||||
index_t y_total,
|
||||
index_t x_total,
|
||||
bool is_top_left = true)
|
||||
{
|
||||
auto r = make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, y_total, x_total, is_top_left);
|
||||
return MaskType{r.at(ck_tile::number<0>{}), r.at(ck_tile::number<1>{}), y_total, x_total};
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -138,7 +138,9 @@ struct FmhaFwdKernel
|
||||
|
||||
struct FmhaFwdMaskKargs
|
||||
{
|
||||
ck_tile::index_t mask_y, mask_x;
|
||||
// ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
struct FmhaFwdFP8Kargs
|
||||
@@ -217,8 +219,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t mask_y,
|
||||
ck_tile::index_t mask_x,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float descale_qk,
|
||||
float descale_sv)
|
||||
{
|
||||
@@ -262,8 +265,9 @@ struct FmhaFwdKernel
|
||||
}
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
kargs.mask_y = mask_y;
|
||||
kargs.mask_x = mask_x;
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
@@ -306,8 +310,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t mask_y,
|
||||
ck_tile::index_t mask_x,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float descale_qk,
|
||||
float descale_sv)
|
||||
{
|
||||
@@ -349,8 +354,9 @@ struct FmhaFwdKernel
|
||||
}
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
kargs.mask_y = mask_y;
|
||||
kargs.mask_x = mask_x;
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
@@ -639,7 +645,12 @@ struct FmhaFwdKernel
|
||||
|
||||
FmhaMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
return FmhaMask{kargs.mask_y, kargs.mask_x, kargs.seqlen_q, kargs.seqlen_k};
|
||||
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
|
||||
kargs.window_size_left,
|
||||
kargs.window_size_right,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_k,
|
||||
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
|
||||
else
|
||||
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
|
||||
}();
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This class is used for codegen pattern matching
|
||||
enum class BlockFmhaPipelineEnum
|
||||
{
|
||||
QRKSVS = 0,
|
||||
QRKSVS_ASYNC,
|
||||
QRKSVS_FP8,
|
||||
QSKSVS,
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user