mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[CK_TILE] FMHA FAv3 scheduling fine-tuning for performance (#2833)
* Re-mapping thread block indices for causal=True kernels
* Use more intuitive remap_opt value
* Fallback to origin remapping if seqlen_q >= 64K
* Use GenericAttentionMask to reduce mask computation
* Avoid unnecessary boundary check for IsMasking=false case
* Fix wrong kernel entry specifier
* Add s_nop to prevent delay wave0-3
* Refine scheduling
* Remove unnecessary sched_group_barrier()
* Move sched_group_barrier() call to scheduler
* Replace inline asm s_setprio with intrinsics
* Rephrase comments
* Expend some o_acc rescaling insts to avoid SIMD idle
* Fix block idx special mapping logic
* Tune block index mapping for causal=False cases
* Tune block index mapping for causal=True cases
* Fix wrong vmcnt()
* Remove parameter name
* Use boolean option for turn on/off causal mask
* Update benchmark_fwd_v3.sh option usages
* Add option if compiler support it
[ROCm/composable_kernel commit: 7fbc9d6c97]
This commit is contained in:
@@ -213,8 +213,20 @@ list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
|
||||
-Wno-undefined-func-template
|
||||
--save-temps
|
||||
)
|
||||
target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
|
||||
set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS)
|
||||
|
||||
check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32)
|
||||
if(HAS_DISABLE_PACKED_FP32)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
|
||||
-mllvm --amdgpu-disable-packed-fp32=1
|
||||
)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS
|
||||
-DCK_TILE_DISABLE_PACKED_FP32=1
|
||||
)
|
||||
endif()
|
||||
|
||||
target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
|
||||
target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS})
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
|
||||
@@ -45,18 +45,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
"permute input\n"
|
||||
"if true, will be b*h*s*d, else b*s*h*d")
|
||||
.insert("operm", "0", "permute output")
|
||||
.insert("mask",
|
||||
"0",
|
||||
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
|
||||
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
|
||||
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
|
||||
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
|
||||
"'xt:window_size', xformer style masking from top-left, window_size negative is "
|
||||
"causal, positive is swa\n"
|
||||
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
|
||||
"causal, positive is swa\n"
|
||||
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
|
||||
"now)")
|
||||
.insert("causal", "0", "0: no mask, 1: causal mask")
|
||||
.insert("v", "1", "0:no verify, 1:verify")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
@@ -109,7 +98,16 @@ struct Problem
|
||||
softmax_scale = args.get_float("scale_s");
|
||||
if(softmax_scale == .0f)
|
||||
softmax_scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim));
|
||||
mask = mask_info::decode(args.get_str("mask"), seqlen_q, seqlen_k);
|
||||
|
||||
const auto is_causal = args.get_bool("causal");
|
||||
if(is_causal)
|
||||
{
|
||||
mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
mask = mask_info::decode("0", seqlen_q, seqlen_k);
|
||||
}
|
||||
|
||||
input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
|
||||
output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
|
||||
|
||||
@@ -34,7 +34,8 @@ struct fmha_fwd_v3_args
|
||||
|
||||
index_t window_size_left;
|
||||
index_t window_size_right;
|
||||
index_t mask_type;
|
||||
index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and
|
||||
// window_size_right == 0).
|
||||
|
||||
const void* q_ptr;
|
||||
index_t stride_q;
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
||||
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \
|
||||
template <> \
|
||||
@@ -79,7 +80,7 @@ struct fmha_fwd_v3_kernel_traits
|
||||
-1 // kBlockPerCu
|
||||
>;
|
||||
|
||||
using fmha_mask = SimplifiedGenericAttentionMask<IsMasking>;
|
||||
using fmha_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
|
||||
|
||||
using fmha_pipeline_problem =
|
||||
BlockFmhaFwdV3PipelineProblem<typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
|
||||
@@ -112,6 +113,22 @@ struct fmha_fwd_v3_kernel_traits
|
||||
template <typename Kernel>
|
||||
float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config)
|
||||
{
|
||||
/// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly
|
||||
/// maximizes the kernel's performance.
|
||||
int remap_opt = 2;
|
||||
if(args.mask_type != static_cast<int>(mask_enum::no_mask) &&
|
||||
((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q)))
|
||||
{
|
||||
if(65536 <= args.seqlen_q)
|
||||
{
|
||||
remap_opt = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
remap_opt = 1;
|
||||
}
|
||||
}
|
||||
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
@@ -140,7 +157,8 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
remap_opt);
|
||||
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
@@ -8,22 +8,16 @@ for prec in "fp16" "bf16" ; do
|
||||
for hdim in 128 ; do
|
||||
for perm in 0 ; do
|
||||
|
||||
if [ $causal -eq 0 ]; then
|
||||
mask=0
|
||||
else
|
||||
mask=b:-1,0
|
||||
fi
|
||||
|
||||
$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
|
||||
$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
@@ -203,27 +203,36 @@ struct GenericAttentionMask
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number<TileHeight>, number<TileWidth>) const
|
||||
{
|
||||
if constexpr(IsLocal)
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
// check top-right corner > x or left-borrom corner < x
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t i_tile_bottom = i_tile_top + TileHeight;
|
||||
index_t x_end = min(i_tile_top + x, x_total);
|
||||
|
||||
bool top_right_edge = i_tile_right > (i_tile_top + x);
|
||||
bool bottom_left_edge = i_tile_bottom > (i_tile_left + y);
|
||||
bool is_partial_out_of_bound = i_tile_right > x_end; // only consider right-pad for now
|
||||
|
||||
return top_right_edge || bottom_left_edge || is_partial_out_of_bound;
|
||||
// TODO: no need to check begin
|
||||
return (i_tile_left + TileWidth) > x_total;
|
||||
}
|
||||
else
|
||||
{
|
||||
// only need to check top-right corner > x
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t x_end = min(i_tile_top + x, x_total);
|
||||
if constexpr(IsLocal)
|
||||
{
|
||||
// check top-right corner > x or left-borrom corner < x
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t i_tile_bottom = i_tile_top + TileHeight;
|
||||
index_t x_end = min(i_tile_top + x, x_total);
|
||||
|
||||
bool top_right_edge = i_tile_right > x_end;
|
||||
return top_right_edge;
|
||||
bool top_right_edge = i_tile_right > (i_tile_top + x);
|
||||
bool bottom_left_edge = i_tile_bottom > (i_tile_left + y);
|
||||
bool is_partial_out_of_bound =
|
||||
i_tile_right > x_end; // only consider right-pad for now
|
||||
|
||||
return top_right_edge || bottom_left_edge || is_partial_out_of_bound;
|
||||
}
|
||||
else
|
||||
{
|
||||
// only need to check top-right corner > x
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t x_end = min(i_tile_top + x, x_total);
|
||||
|
||||
bool top_right_edge = i_tile_right > x_end;
|
||||
return top_right_edge;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -81,6 +81,7 @@ struct FmhaFwdV3Kernel
|
||||
// ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
ck_tile::index_t remap_opt;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonLSEKargs
|
||||
@@ -143,7 +144,8 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t remap_opt)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -176,6 +178,7 @@ struct FmhaFwdV3Kernel
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
kargs.remap_opt = remap_opt;
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
@@ -213,7 +216,8 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t remap_opt)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -245,6 +249,7 @@ struct FmhaFwdV3Kernel
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
kargs.remap_opt = remap_opt;
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
@@ -261,39 +266,81 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t hdim_v_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
return dim3(nhead_,
|
||||
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
|
||||
batch_size_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
return dim3(nhead_,
|
||||
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
|
||||
batch_size_);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option)
|
||||
{
|
||||
if(remap_option < 1)
|
||||
{
|
||||
return make_tuple(static_cast<int32_t>(gridDim.x - tg_idx - 1), tg_idy);
|
||||
}
|
||||
|
||||
int32_t remapped_tg_idx = tg_idx;
|
||||
int32_t remapped_tg_idy = tg_idy;
|
||||
|
||||
if(remap_option == 2)
|
||||
{ // special remapping
|
||||
int32_t tmp0 = (remapped_tg_idy & 0x7) * gridDim.x + remapped_tg_idx;
|
||||
int32_t tmp1 = tmp0 & 0x7;
|
||||
|
||||
remapped_tg_idx = tmp0 >> 3;
|
||||
remapped_tg_idy = (remapped_tg_idy & 0xfffffff8) + tmp1;
|
||||
}
|
||||
else
|
||||
{ // normal remapping
|
||||
int32_t cus_per_xdim_per_xcc = gridDim.x >> 3;
|
||||
int32_t tgs_cu_id = remapped_tg_idx >> 3;
|
||||
|
||||
if(tgs_cu_id < cus_per_xdim_per_xcc)
|
||||
{
|
||||
int32_t tgs_xcc_id = remapped_tg_idx & 0x7;
|
||||
int32_t new_tg_idx = tgs_xcc_id * cus_per_xdim_per_xcc + tgs_cu_id;
|
||||
|
||||
remapped_tg_idx = new_tg_idx;
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(remapped_tg_idx, remapped_tg_idy);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs&)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
|
||||
// FmhaPipeline::kN1);
|
||||
|
||||
// assume that num_tile_n1 is always 1
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
const index_t i_nhead = blockIdx.x;
|
||||
const index_t i_block = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t i_nhead = blockIdx.x;
|
||||
const index_t i_block = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -57,7 +57,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 1) {}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
@@ -68,11 +72,19 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 3) {}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(Phase == 0) {}
|
||||
if constexpr(Phase == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
static_for<0, 8, 1>{}([&](auto) {
|
||||
@@ -81,7 +93,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 2) {}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
@@ -115,7 +131,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 1) {}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
@@ -126,11 +146,19 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 3) {}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(Phase == 0) {}
|
||||
if constexpr(Phase == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
static_for<0, 8, 1>{}([&](auto) {
|
||||
@@ -139,7 +167,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 2) {}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
@@ -177,6 +209,15 @@ CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
|
||||
return result;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
|
||||
{
|
||||
float result;
|
||||
asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]"
|
||||
: [result] "=v"(result)
|
||||
: [lhs] "v"(lhs), [rhs] "v"(rhs));
|
||||
return result;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
|
||||
{
|
||||
fp16x2_t result;
|
||||
@@ -466,7 +507,7 @@ struct BlockFmhaFwdV3Pipeline
|
||||
statically_indexed_array<sp_compute_type, 2> sp;
|
||||
|
||||
decltype(gemm_1.MakeCBlockTile()) o_acc;
|
||||
constexpr index_t fmha_alu_D_reg_cnt = 0; // threshold to decide how many fmha_alu_D_upd()
|
||||
constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd()
|
||||
// instructions should we move to fmha_alu1()
|
||||
static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size());
|
||||
|
||||
@@ -631,8 +672,8 @@ struct BlockFmhaFwdV3Pipeline
|
||||
|
||||
// K_mem_su_ld_insts = 1 for 32 x 128
|
||||
// V_mem_su_ld_insts = 1 for 128 x 32
|
||||
static constexpr int K_mem_su_ld_insts = 1;
|
||||
static constexpr int V_mem_su_ld_insts = 1;
|
||||
constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access();
|
||||
constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();
|
||||
|
||||
auto K_mem_load = [&](auto k_lds_write_idx) {
|
||||
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
@@ -648,7 +689,6 @@ struct BlockFmhaFwdV3Pipeline
|
||||
|
||||
auto V_mem_load = [&](auto v_lds_write_idx) {
|
||||
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
/// FIXME: use the future-predicting method to move the window
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
@@ -726,11 +766,12 @@ struct BlockFmhaFwdV3Pipeline
|
||||
#else
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
#endif
|
||||
// update partial o_acc [0, 2)
|
||||
static_for<0, ck_tile::min(2, fmha_alu_D_reg_cnt), 1>{}(
|
||||
[&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
|
||||
|
||||
// l{j}
|
||||
/// Note: The compiler keeps moving the following instructions elsewhere because 'l'
|
||||
/// is first consumed later. To anchor them here, we rewrite the final addition in
|
||||
/// inline assembly to create a dependency, forcing the dependent instructions to
|
||||
/// be emitted at this point.
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
@@ -739,13 +780,15 @@ struct BlockFmhaFwdV3Pipeline
|
||||
l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]);
|
||||
});
|
||||
|
||||
// update partial o_acc [2, fmha_alu_D_reg_cnt)
|
||||
static_for<2, ck_tile::max(2, fmha_alu_D_reg_cnt), 1>{}(
|
||||
[&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
|
||||
// update partial o_acc [0, fmha_alu_D_reg_cnt)
|
||||
static_for<0, fmha_alu_D_reg_cnt, 1>{}([&](auto idx) {
|
||||
o_acc.thread_buf_[idx] = detail::mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale);
|
||||
});
|
||||
|
||||
/// NOTICE: Compiler keep moving the conversion instructions to other places. We rewite
|
||||
/// the cast_tile() call into inline asm to force the conversion instructions to be
|
||||
/// generated here. The fmha_alu1() call should be placed at the end of a phase.
|
||||
/// Note: The compiler keeps sinking the conversion instructions because the
|
||||
/// result 'p' is only consumed later. To anchor them here, we rewrite
|
||||
/// the cast_tile() call as inline assembly, forcing the conversions to be
|
||||
/// emitted at this point.
|
||||
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
|
||||
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) {
|
||||
float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
|
||||
@@ -763,6 +806,10 @@ struct BlockFmhaFwdV3Pipeline
|
||||
sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
|
||||
}
|
||||
});
|
||||
|
||||
/// Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly
|
||||
/// can interfere with the behavior of sched_group_barrier(), so ending the phase here
|
||||
/// avoids unintended reordering.
|
||||
};
|
||||
|
||||
auto gemm = [&](auto sp_reg_idx, auto gemm_idx) {
|
||||
@@ -937,9 +984,9 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
|
||||
Scheduler::schedule(cl_p, number<1>{});
|
||||
fmha_mask(xdl_SP_p01_reg_idx);
|
||||
|
||||
Scheduler::schedule(cl_p, number<1>{});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// phase2
|
||||
ASM_MARKER("phase2 Wave0-3");
|
||||
@@ -947,6 +994,8 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
asm volatile("s_nop 0");
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_calc(xdl_SP_p23_reg_idx, gemm1);
|
||||
|
||||
Scheduler::schedule(cl_p, number<2>{});
|
||||
@@ -995,6 +1044,8 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
asm volatile("s_nop 1");
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_calc(xdl_SP_p01_reg_idx, gemm0);
|
||||
fmha_alu1(xdl_SP_p23_reg_idx);
|
||||
|
||||
@@ -1005,9 +1056,9 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx);
|
||||
Scheduler::schedule(cl_p, number<2>{});
|
||||
fmha_mask(xdl_SP_p01_reg_idx);
|
||||
|
||||
Scheduler::schedule(cl_p, number<2>{});
|
||||
kv_token_start += kN0;
|
||||
if(num_total_loop <= ++i_total_loops)
|
||||
{
|
||||
@@ -1021,6 +1072,8 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
asm volatile("s_nop 1");
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_calc(xdl_SP_p23_reg_idx, gemm1);
|
||||
|
||||
Scheduler::schedule(cl_p, number<3>{});
|
||||
@@ -1036,7 +1089,14 @@ struct BlockFmhaFwdV3Pipeline
|
||||
auto ps_pi = number<1>{} - d;
|
||||
auto V_lds_rd_idx = ps_pi;
|
||||
|
||||
s_waitcnt_vmcnt<K_mem_su_ld_insts>();
|
||||
if(1 < num_total_loop)
|
||||
{
|
||||
s_waitcnt_vmcnt<K_mem_su_ld_insts>();
|
||||
}
|
||||
else
|
||||
{
|
||||
s_waitcnt_vmcnt<0>();
|
||||
}
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
V_lds_load(V_lds_rd_idx);
|
||||
@@ -1102,14 +1162,14 @@ struct BlockFmhaFwdV3Pipeline
|
||||
V_mem_load(number<1>{}); // V1
|
||||
K_lds_load(number<1>{}); // K1
|
||||
|
||||
asm volatile("s_setprio 0");
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
while(core_loop(number<0>{}))
|
||||
;
|
||||
}
|
||||
if(warp_group_id != 0)
|
||||
{
|
||||
asm volatile("s_setprio 1");
|
||||
__builtin_amdgcn_s_setprio(1);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
while(core_loop(number<1>{}))
|
||||
;
|
||||
@@ -1167,14 +1227,13 @@ struct BlockFmhaFwdV3Pipeline
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user