mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[CKTILE] FMHA fwd trload lse fix (#3046)
* enable storelse for fmha_fwd_trload kernel * fix lse in trload * fix the mask related bug
This commit is contained in:
@@ -211,10 +211,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
|
||||
if(get_thread_local_1d_id() < kM0)
|
||||
{
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
|
||||
// Note: here occ are all cleard, return it
|
||||
@@ -256,8 +253,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
// physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
|
||||
const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start;
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution<Problem>());
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp,
|
||||
{physical_seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
auto k_lds_write_view = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<KDataType*>(smem_ptr), Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
@@ -289,8 +288,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
Policy::template MakeSRegTileDistribution<Problem>());
|
||||
|
||||
// V tile in LDS
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution<Problem>());
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp,
|
||||
{physical_seqlen_k_start, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto v_lds_write_view = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
|
||||
@@ -393,7 +394,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
{
|
||||
if(i_total_loops == (num_total_loop - 1))
|
||||
{
|
||||
const auto k_origin = make_tuple(kN0 * i_total_loops, 0);
|
||||
const auto k_origin =
|
||||
make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0);
|
||||
set_tile_if(s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&,
|
||||
@@ -410,7 +412,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = make_tuple(kN0 * i_total_loops, 0);
|
||||
const auto k_origin = make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0);
|
||||
|
||||
bool need_perpixel_check =
|
||||
mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number<kM0>{}, number<kN0>{});
|
||||
@@ -602,10 +604,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
}
|
||||
});
|
||||
|
||||
if(get_thread_local_1d_id() < kM0)
|
||||
{
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
|
||||
// finally, O
|
||||
@@ -717,10 +716,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
|
||||
if(get_thread_local_1d_id() < kM0)
|
||||
{
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
|
||||
// Note: here occ are all cleard, return it
|
||||
@@ -765,8 +761,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
// physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
|
||||
const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start;
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution<Problem, true>());
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp,
|
||||
{physical_seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem, true>());
|
||||
|
||||
auto k_lds_write_view = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<KDataType* __restrict__>(smem_ptrk0),
|
||||
@@ -801,8 +799,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
Policy::template MakeSRegTileDistribution<Problem>());
|
||||
|
||||
// V tile in LDS
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution<Problem>());
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp,
|
||||
{physical_seqlen_k_start, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
auto v_lds_write_view = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv0)),
|
||||
@@ -901,7 +901,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
{
|
||||
if(i_total_loops == (num_total_loop - 1))
|
||||
{
|
||||
const auto k_origin = make_tuple(kN0 * i_total_loops, 0);
|
||||
const auto k_origin =
|
||||
make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0);
|
||||
set_tile_if(s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&,
|
||||
@@ -918,7 +919,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = make_tuple(kN0 * i_total_loops, 0);
|
||||
const auto k_origin = make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0);
|
||||
|
||||
bool need_perpixel_check =
|
||||
mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number<kM0>{}, number<kN0>{});
|
||||
@@ -1146,10 +1147,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
}
|
||||
});
|
||||
|
||||
if(get_thread_local_1d_id() < kM0)
|
||||
{
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
|
||||
// finally, O
|
||||
|
||||
Reference in New Issue
Block a user