mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[CK_TILE] wa prec, remove sgpr offset for inline asm (#1356)
* wa prec, remove sgpr offset for inline asm * macro for set tile * ignore unused param if no kernel instances in host API * fix more prec issue * cache buffer resource * fix * support pre-nop * clear tile by vector type members * add workaround to reduce scratch memory * conditionally enable workaround code * enable workaround start from certain build version * fallback set_tile() implementation from certain build version * undo template argument changes * put dummy asm in load_raw() * fix comments, refactor s_nop inside buffer_load --------- Co-authored-by: PoYen, Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -81,6 +81,12 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
return Problem::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
// minimize occupancy
|
||||
if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
if constexpr(kK0BlockLength <= 32)
|
||||
{
|
||||
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
|
||||
@@ -220,6 +226,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
q_dram_window.init_raw();
|
||||
|
||||
// TODO: we use async Copy for K, which is inline asm
|
||||
// a side effect is we have to use inline asm for q as well
|
||||
@@ -293,6 +300,17 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
k_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
k_dram_window.init_raw();
|
||||
constexpr auto k_oob_ck = bool_constant<true>{};
|
||||
constexpr auto k_pre_np = [&]() {
|
||||
if constexpr(kPadSeqLenK &&
|
||||
(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)))
|
||||
return bool_constant<true>{};
|
||||
else
|
||||
return bool_constant<false>{};
|
||||
}();
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window = make_tile_window(
|
||||
bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -310,7 +328,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
// prefetch K tile
|
||||
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
|
||||
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -333,7 +351,9 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
|
||||
k_dram_window);
|
||||
k_dram_window,
|
||||
k_oob_ck,
|
||||
k_pre_np);
|
||||
if constexpr(i_k0 < k0_loops - 1)
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
@@ -637,16 +657,13 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
{
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
k_dram_window =
|
||||
make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
|
||||
if constexpr(k1_loops >= 2 &&
|
||||
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
|
||||
__builtin_amdgcn_s_barrier();
|
||||
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
|
||||
async_load_tile_raw(
|
||||
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
}
|
||||
// tail
|
||||
|
||||
Reference in New Issue
Block a user