mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] FMHA FAv3 scheduling fine-tuning for performance (#2833)
* Re-mapping thread block indices for causal=True kernels * Use more intuitive remap_opt value * Fallback to origin remapping if seqlen_q >= 64K * Use GenericAttentionMask to reduce mask computation * Avoid unnecessary boundary check for IsMasking=false case * Fix wrong kernel entry specifier * Add s_nop to prevent delay wave0-3 * Refine scheduling * Remove unnecessary sched_group_barrier() * Move sched_group_barrier() call to scheduler * Replace inline asm s_setprio with intrinsics * Rephrase comments * Expend some o_acc rescaling insts to avoid SIMD idle * Fix block idx special mapping logic * Tune block index mapping for causal=False cases * Tune block index mapping for causal=True cases * Fix wrong vmcnt() * Remove parameter name * Use boolean option for turn on/off causal mask * Update benchmark_fwd_v3.sh option usages * Add option if compiler support it
This commit is contained in:
@@ -203,27 +203,36 @@ struct GenericAttentionMask
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number<TileHeight>, number<TileWidth>) const
|
||||
{
|
||||
if constexpr(IsLocal)
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
// check top-right corner > x or left-borrom corner < x
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t i_tile_bottom = i_tile_top + TileHeight;
|
||||
index_t x_end = min(i_tile_top + x, x_total);
|
||||
|
||||
bool top_right_edge = i_tile_right > (i_tile_top + x);
|
||||
bool bottom_left_edge = i_tile_bottom > (i_tile_left + y);
|
||||
bool is_partial_out_of_bound = i_tile_right > x_end; // only consider right-pad for now
|
||||
|
||||
return top_right_edge || bottom_left_edge || is_partial_out_of_bound;
|
||||
// TODO: no need to check begin
|
||||
return (i_tile_left + TileWidth) > x_total;
|
||||
}
|
||||
else
|
||||
{
|
||||
// only need to check top-right corner > x
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t x_end = min(i_tile_top + x, x_total);
|
||||
if constexpr(IsLocal)
|
||||
{
|
||||
// check top-right corner > x or left-borrom corner < x
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t i_tile_bottom = i_tile_top + TileHeight;
|
||||
index_t x_end = min(i_tile_top + x, x_total);
|
||||
|
||||
bool top_right_edge = i_tile_right > x_end;
|
||||
return top_right_edge;
|
||||
bool top_right_edge = i_tile_right > (i_tile_top + x);
|
||||
bool bottom_left_edge = i_tile_bottom > (i_tile_left + y);
|
||||
bool is_partial_out_of_bound =
|
||||
i_tile_right > x_end; // only consider right-pad for now
|
||||
|
||||
return top_right_edge || bottom_left_edge || is_partial_out_of_bound;
|
||||
}
|
||||
else
|
||||
{
|
||||
// only need to check top-right corner > x
|
||||
index_t i_tile_right = i_tile_left + TileWidth;
|
||||
index_t x_end = min(i_tile_top + x, x_total);
|
||||
|
||||
bool top_right_edge = i_tile_right > x_end;
|
||||
return top_right_edge;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -81,6 +81,7 @@ struct FmhaFwdV3Kernel
|
||||
// ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::index_t window_size_left, window_size_right;
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
ck_tile::index_t remap_opt;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonLSEKargs
|
||||
@@ -143,7 +144,8 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t remap_opt)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -176,6 +178,7 @@ struct FmhaFwdV3Kernel
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
kargs.remap_opt = remap_opt;
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
@@ -213,7 +216,8 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t remap_opt)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -245,6 +249,7 @@ struct FmhaFwdV3Kernel
|
||||
kargs.window_size_left = window_size_left;
|
||||
kargs.window_size_right = window_size_right;
|
||||
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
|
||||
kargs.remap_opt = remap_opt;
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
@@ -261,39 +266,81 @@ struct FmhaFwdV3Kernel
|
||||
ck_tile::index_t hdim_v_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
return dim3(nhead_,
|
||||
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
|
||||
batch_size_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
return dim3(nhead_,
|
||||
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
|
||||
batch_size_);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option)
|
||||
{
|
||||
if(remap_option < 1)
|
||||
{
|
||||
return make_tuple(static_cast<int32_t>(gridDim.x - tg_idx - 1), tg_idy);
|
||||
}
|
||||
|
||||
int32_t remapped_tg_idx = tg_idx;
|
||||
int32_t remapped_tg_idy = tg_idy;
|
||||
|
||||
if(remap_option == 2)
|
||||
{ // special remapping
|
||||
int32_t tmp0 = (remapped_tg_idy & 0x7) * gridDim.x + remapped_tg_idx;
|
||||
int32_t tmp1 = tmp0 & 0x7;
|
||||
|
||||
remapped_tg_idx = tmp0 >> 3;
|
||||
remapped_tg_idy = (remapped_tg_idy & 0xfffffff8) + tmp1;
|
||||
}
|
||||
else
|
||||
{ // normal remapping
|
||||
int32_t cus_per_xdim_per_xcc = gridDim.x >> 3;
|
||||
int32_t tgs_cu_id = remapped_tg_idx >> 3;
|
||||
|
||||
if(tgs_cu_id < cus_per_xdim_per_xcc)
|
||||
{
|
||||
int32_t tgs_xcc_id = remapped_tg_idx & 0x7;
|
||||
int32_t new_tg_idx = tgs_xcc_id * cus_per_xdim_per_xcc + tgs_cu_id;
|
||||
|
||||
remapped_tg_idx = new_tg_idx;
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(remapped_tg_idx, remapped_tg_idy);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs&)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
|
||||
// FmhaPipeline::kN1);
|
||||
|
||||
// assume that num_tile_n1 is always 1
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
const index_t i_nhead = blockIdx.x;
|
||||
const index_t i_block = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t i_nhead = blockIdx.x;
|
||||
const index_t i_block = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -57,7 +57,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 1) {}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
@@ -68,11 +72,19 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 3) {}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(Phase == 0) {}
|
||||
if constexpr(Phase == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
static_for<0, 8, 1>{}([&](auto) {
|
||||
@@ -81,7 +93,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 2) {}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
@@ -115,7 +131,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 1) {}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
@@ -126,11 +146,19 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 3) {}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(Phase == 0) {}
|
||||
if constexpr(Phase == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
static_for<0, 8, 1>{}([&](auto) {
|
||||
@@ -139,7 +167,11 @@ struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 2) {}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
@@ -177,6 +209,15 @@ CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
|
||||
return result;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
|
||||
{
|
||||
float result;
|
||||
asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]"
|
||||
: [result] "=v"(result)
|
||||
: [lhs] "v"(lhs), [rhs] "v"(rhs));
|
||||
return result;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
|
||||
{
|
||||
fp16x2_t result;
|
||||
@@ -466,7 +507,7 @@ struct BlockFmhaFwdV3Pipeline
|
||||
statically_indexed_array<sp_compute_type, 2> sp;
|
||||
|
||||
decltype(gemm_1.MakeCBlockTile()) o_acc;
|
||||
constexpr index_t fmha_alu_D_reg_cnt = 0; // threshold to decide how many fmha_alu_D_upd()
|
||||
constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd()
|
||||
// instructions should we move to fmha_alu1()
|
||||
static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size());
|
||||
|
||||
@@ -631,8 +672,8 @@ struct BlockFmhaFwdV3Pipeline
|
||||
|
||||
// K_mem_su_ld_insts = 1 for 32 x 128
|
||||
// V_mem_su_ld_insts = 1 for 128 x 32
|
||||
static constexpr int K_mem_su_ld_insts = 1;
|
||||
static constexpr int V_mem_su_ld_insts = 1;
|
||||
constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access();
|
||||
constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();
|
||||
|
||||
auto K_mem_load = [&](auto k_lds_write_idx) {
|
||||
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
|
||||
@@ -648,7 +689,6 @@ struct BlockFmhaFwdV3Pipeline
|
||||
|
||||
auto V_mem_load = [&](auto v_lds_write_idx) {
|
||||
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
/// FIXME: use the future-predicting method to move the window
|
||||
move_tile_window(v_dram_window, {kK1, 0});
|
||||
@@ -726,11 +766,12 @@ struct BlockFmhaFwdV3Pipeline
|
||||
#else
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
#endif
|
||||
// update partial o_acc [0, 2)
|
||||
static_for<0, ck_tile::min(2, fmha_alu_D_reg_cnt), 1>{}(
|
||||
[&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
|
||||
|
||||
// l{j}
|
||||
/// Note: The compiler keeps moving the following instructions elsewhere because 'l'
|
||||
/// is first consumed later. To anchor them here, we rewrite the final addition in
|
||||
/// inline assembly to create a dependency, forcing the dependent instructions to
|
||||
/// be emitted at this point.
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
@@ -739,13 +780,15 @@ struct BlockFmhaFwdV3Pipeline
|
||||
l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]);
|
||||
});
|
||||
|
||||
// update partial o_acc [2, fmha_alu_D_reg_cnt)
|
||||
static_for<2, ck_tile::max(2, fmha_alu_D_reg_cnt), 1>{}(
|
||||
[&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
|
||||
// update partial o_acc [0, fmha_alu_D_reg_cnt)
|
||||
static_for<0, fmha_alu_D_reg_cnt, 1>{}([&](auto idx) {
|
||||
o_acc.thread_buf_[idx] = detail::mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale);
|
||||
});
|
||||
|
||||
/// NOTICE: Compiler keep moving the conversion instructions to other places. We rewite
|
||||
/// the cast_tile() call into inline asm to force the conversion instructions to be
|
||||
/// generated here. The fmha_alu1() call should be placed at the end of a phase.
|
||||
/// Note: The compiler keeps sinking the conversion instructions because the
|
||||
/// result 'p' is only consumed later. To anchor them here, we rewrite
|
||||
/// the cast_tile() call as inline assembly, forcing the conversions to be
|
||||
/// emitted at this point.
|
||||
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
|
||||
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) {
|
||||
float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
|
||||
@@ -763,6 +806,10 @@ struct BlockFmhaFwdV3Pipeline
|
||||
sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
|
||||
}
|
||||
});
|
||||
|
||||
/// Note: Place fmha_alu1() at the end of the phase. The surrounding inline assembly
|
||||
/// can interfere with the behavior of sched_group_barrier(), so ending the phase here
|
||||
/// avoids unintended reordering.
|
||||
};
|
||||
|
||||
auto gemm = [&](auto sp_reg_idx, auto gemm_idx) {
|
||||
@@ -937,9 +984,9 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
|
||||
Scheduler::schedule(cl_p, number<1>{});
|
||||
fmha_mask(xdl_SP_p01_reg_idx);
|
||||
|
||||
Scheduler::schedule(cl_p, number<1>{});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// phase2
|
||||
ASM_MARKER("phase2 Wave0-3");
|
||||
@@ -947,6 +994,8 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
asm volatile("s_nop 0");
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_calc(xdl_SP_p23_reg_idx, gemm1);
|
||||
|
||||
Scheduler::schedule(cl_p, number<2>{});
|
||||
@@ -995,6 +1044,8 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
asm volatile("s_nop 1");
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_calc(xdl_SP_p01_reg_idx, gemm0);
|
||||
fmha_alu1(xdl_SP_p23_reg_idx);
|
||||
|
||||
@@ -1005,9 +1056,9 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx);
|
||||
Scheduler::schedule(cl_p, number<2>{});
|
||||
fmha_mask(xdl_SP_p01_reg_idx);
|
||||
|
||||
Scheduler::schedule(cl_p, number<2>{});
|
||||
kv_token_start += kN0;
|
||||
if(num_total_loop <= ++i_total_loops)
|
||||
{
|
||||
@@ -1021,6 +1072,8 @@ struct BlockFmhaFwdV3Pipeline
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
asm volatile("s_nop 1");
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
cl_calc(xdl_SP_p23_reg_idx, gemm1);
|
||||
|
||||
Scheduler::schedule(cl_p, number<3>{});
|
||||
@@ -1036,7 +1089,14 @@ struct BlockFmhaFwdV3Pipeline
|
||||
auto ps_pi = number<1>{} - d;
|
||||
auto V_lds_rd_idx = ps_pi;
|
||||
|
||||
s_waitcnt_vmcnt<K_mem_su_ld_insts>();
|
||||
if(1 < num_total_loop)
|
||||
{
|
||||
s_waitcnt_vmcnt<K_mem_su_ld_insts>();
|
||||
}
|
||||
else
|
||||
{
|
||||
s_waitcnt_vmcnt<0>();
|
||||
}
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
V_lds_load(V_lds_rd_idx);
|
||||
@@ -1102,14 +1162,14 @@ struct BlockFmhaFwdV3Pipeline
|
||||
V_mem_load(number<1>{}); // V1
|
||||
K_lds_load(number<1>{}); // K1
|
||||
|
||||
asm volatile("s_setprio 0");
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
while(core_loop(number<0>{}))
|
||||
;
|
||||
}
|
||||
if(warp_group_id != 0)
|
||||
{
|
||||
asm volatile("s_setprio 1");
|
||||
__builtin_amdgcn_s_setprio(1);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
while(core_loop(number<1>{}))
|
||||
;
|
||||
@@ -1167,14 +1227,13 @@ struct BlockFmhaFwdV3Pipeline
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user