Remove all LDS bank conflicts with XOR layouts

This commit is contained in:
aska-0096
2025-07-30 12:25:33 +00:00
parent 8dacc35c4c
commit 69890afc98
5 changed files with 225 additions and 51 deletions

View File

@@ -582,6 +582,15 @@ struct FmhaFwdDecodeKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// TODO: Refine the logical here.
// In Decode case
// 1. we don't expect KV data reused by different ThreadGroups, bypass the cache
// 2. limit the LDS usage, as we want higher occupancy
// In Prefill case
// 1. we expect KV data reused by different ThreadGroups, use cache
// 2. use more LDS, as we want better memory latency hiding
// If SplitKV off, we don't expect Q data reused by different ThreadGroups, bypass the cache
constexpr bool PrefillCase = FmhaPipeline::kM0 == 128;
// divide problem
const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs);
@@ -710,7 +719,9 @@ struct FmhaFwdDecodeKernel
// reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q,
// hdim_q)
// We expect Q data reuse among different KVSplited in decode case.
const auto view = make_naive_tensor_view<address_space_enum::global>(
const auto view = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
amd_buffer_coherence_enum::SYSTEM_NT1>(
q_ptr,
make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.nhead_stride_q, kargs.stride_q, 1),
@@ -727,7 +738,9 @@ struct FmhaFwdDecodeKernel
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
return make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
amd_buffer_coherence_enum::SYSTEM_NT1>(
q_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
@@ -753,20 +766,44 @@ struct FmhaFwdDecodeKernel
}();
const auto make_k_dram = [&](const KDataType* data, index_t height) {
// We don't expect K data reuse among different blocks in decode case.
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
amd_buffer_coherence_enum::SYSTEM_NT1>(
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
data, // will update this pointer if using paged-kvcache
make_tuple(height, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
return pad_tensor_view(
const auto k_dram_pad = pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<false, kPadHeadDimQ>{});
const auto k_dram_unmerged = transform_tensor_view(
k_dram_pad,
make_tuple(make_pass_through_transform(height),
make_unmerge_transform(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto k_dram_permuted = transform_tensor_view(
k_dram_unmerged,
make_tuple(
make_xor_transform(make_tuple(
height, number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
make_pass_through_transform(number<FmhaPipeline::kAlignmentK>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_view(
k_dram_permuted,
make_tuple(make_pass_through_transform(height),
make_merge_transform_v3_division_mod(make_tuple(
number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{},
number<FmhaPipeline::kAlignmentK>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
};
const auto k_dram = [&]() {
if constexpr(kIsPagedKV)
@@ -780,19 +817,47 @@ struct FmhaFwdDecodeKernel
}();
const auto make_v_dram = [&](const VDataType* data, index_t length) {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
amd_buffer_coherence_enum::SYSTEM_NT1>(
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
data, // will update this pointer if using paged-kvcache
make_tuple(length, kargs.hdim_v),
make_tuple(kargs.hdim_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
return pad_tensor_view(
const auto v_dram_pad = pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
sequence<kPadSeqLenK, false>{});
// TODO: Add kVHeadDim
// TrLoad Performed in 16x4/16x8/16x16 unit, the fast dimension is 16 elements
constexpr auto TrLoadFastDimLength = 16;
const auto v_dram_unmerged = transform_tensor_view(
v_dram_pad,
make_tuple(make_pass_through_transform(length),
make_unmerge_transform(
make_tuple(number<FmhaPipeline::kQKHeaddim / TrLoadFastDimLength>{},
number<TrLoadFastDimLength>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto v_dram_permuted = transform_tensor_view(
v_dram_unmerged,
make_tuple(make_xor_transform(make_tuple(
length, number<FmhaPipeline::kQKHeaddim / TrLoadFastDimLength>{})),
make_pass_through_transform(number<TrLoadFastDimLength>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_view(
v_dram_permuted,
make_tuple(make_pass_through_transform(length),
make_merge_transform_v3_division_mod(
make_tuple(number<FmhaPipeline::kQKHeaddim / TrLoadFastDimLength>{},
number<TrLoadFastDimLength>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
};
const auto v_dram = [&]() {
@@ -945,7 +1010,7 @@ struct FmhaFwdDecodeKernel
}();
auto o_acc_tile = [&, i_split_ = i_split]() {
if constexpr(FmhaPipeline::kM0 == 128)
if constexpr(PrefillCase)
{
// allocate double lds
// add __restrict__ here to avoid aliasing

View File

@@ -253,12 +253,18 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
auto k_dram_window = make_tile_window(
k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution<Problem>());
auto k_lds = make_tensor_view<address_space_enum::lds>(
auto k_lds_write_view = make_tensor_view<address_space_enum::lds>(
static_cast<KDataType*>(smem_ptr), Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_write_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
auto k_lds_read_view = make_tensor_view<address_space_enum::lds>(
static_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptor<Problem, false, true>());
auto k_lds_write_window =
make_tile_window(k_lds_write_view,
Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0});
auto k_lds_read_window =
make_tile_window(k_lds,
make_tile_window(k_lds_read_view,
make_tuple(number<kN0>{}, number<kK0>{}),
{0, 0},
Policy::template MakeKRegTileDistribution<Problem>());
@@ -280,16 +286,23 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution<Problem>());
auto v_lds = make_tensor_view<address_space_enum::lds>(
auto v_lds_write_view = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeS<Problem>()),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_write_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
auto v_lds_read_view = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeS<Problem>()),
Policy::template MakeVLdsBlockDescriptor<Problem, true>());
auto v_lds_write_window =
make_tile_window(v_lds_write_view,
Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0});
auto v_lds_read_window =
make_tile_window(v_lds,
make_tile_window(v_lds_read_view,
make_tuple(number<kK1>{}, number<kN1>{}),
{0, 0},
Policy::template MakeVRegTileDistribution<Problem>());
@@ -745,29 +758,38 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
auto k_dram_window = make_tile_window(
k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution<Problem, true>());
auto k_lds = make_tuple(make_tensor_view<address_space_enum::lds>(
static_cast<KDataType* __restrict__>(smem_ptrk0),
Policy::template MakeKLdsBlockDescriptor<Problem, true>()),
make_tensor_view<address_space_enum::lds>(
static_cast<KDataType* __restrict__>(smem_ptrk1),
Policy::template MakeKLdsBlockDescriptor<Problem, true>()));
auto k_lds_write_view =
make_tuple(make_tensor_view<address_space_enum::lds>(
static_cast<KDataType* __restrict__>(smem_ptrk0),
Policy::template MakeKLdsBlockDescriptor<Problem, true>()),
make_tensor_view<address_space_enum::lds>(
static_cast<KDataType* __restrict__>(smem_ptrk1),
Policy::template MakeKLdsBlockDescriptor<Problem, true>()));
auto k_lds_read_view =
make_tuple(make_tensor_view<address_space_enum::lds>(
static_cast<KDataType* __restrict__>(smem_ptrk0),
Policy::template MakeKLdsBlockDescriptor<Problem, true, true>()),
make_tensor_view<address_space_enum::lds>(
static_cast<KDataType* __restrict__>(smem_ptrk1),
Policy::template MakeKLdsBlockDescriptor<Problem, true, true>()));
auto k_lds_write_windows =
make_tuple(make_tile_window(
k_lds.at(I0),
k_lds_write_view.at(I0),
Policy::template MakeKLdsBlockDescriptor<Problem, true>().get_lengths(),
{0, 0}),
make_tile_window(
k_lds.at(I1),
k_lds_write_view.at(I1),
Policy::template MakeKLdsBlockDescriptor<Problem, true>().get_lengths(),
{0, 0}));
auto k_lds_read_windows =
make_tuple(make_tile_window(k_lds.at(I0),
make_tuple(make_tile_window(k_lds_read_view.at(I0),
make_tuple(number<kN0>{}, number<kK0>{}),
{0, 0},
Policy::template MakeKRegTileDistribution<Problem>()),
make_tile_window(k_lds.at(I1),
make_tile_window(k_lds_read_view.at(I1),
make_tuple(number<kN0>{}, number<kK0>{}),
{0, 0},
Policy::template MakeKRegTileDistribution<Problem>()));
@@ -789,7 +811,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution<Problem>());
auto v_lds = make_tuple(
auto v_lds_write_view = make_tuple(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv0)),
Policy::template MakeVLdsBlockDescriptor<Problem>()),
@@ -797,20 +819,28 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv1)),
Policy::template MakeVLdsBlockDescriptor<Problem>()));
auto v_lds_read_view = make_tuple(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv0)),
Policy::template MakeVLdsBlockDescriptor<Problem, true>()),
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv1)),
Policy::template MakeVLdsBlockDescriptor<Problem, true>()));
auto v_lds_write_windows = make_tuple(
make_tile_window(v_lds.at(I0),
make_tile_window(v_lds_write_view.at(I0),
Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0}),
make_tile_window(v_lds.at(I1),
make_tile_window(v_lds_write_view.at(I1),
Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0}));
auto v_lds_read_windows =
make_tuple(make_tile_window(v_lds.at(I0),
make_tuple(make_tile_window(v_lds_read_view.at(I0),
make_tuple(number<kK1>{}, number<kN1>{}),
{0, 0},
Policy::template MakeVRegTileDistribution<Problem>()),
make_tile_window(v_lds.at(I1),
make_tile_window(v_lds_read_view.at(I1),
make_tuple(number<kK1>{}, number<kN1>{}),
{0, 0},
Policy::template MakeVRegTileDistribution<Problem>()));

View File

@@ -224,7 +224,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
return q_lds_block_desc;
}
template <typename Problem, bool LoadOnce = false>
template <typename Problem, bool LoadOnce = false, bool Xor = false>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
@@ -233,16 +233,53 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr auto k_lds_block_desc =
make_naive_tensor_descriptor(make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
make_tuple(number<kKPerBlock>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto k_lds_block_desc = [&]() {
if constexpr(Xor)
{
constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
make_tuple(number<kKPerBlock>{}, number<1>{}),
number<kKPack>{},
number<1>{});
const auto k_lds_block_desc_unmerged = transform_tensor_descriptor(
k_lds_block_desc_naive,
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_unmerge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto k_lds_block_desc_permuted = transform_tensor_descriptor(
k_lds_block_desc_unmerged,
make_tuple(make_xor_transform(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_descriptor(
k_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
make_tuple(number<kKPerBlock>{}, number<1>{}),
number<kKPack>{},
number<1>{});
}
}();
return k_lds_block_desc;
}
template <typename Problem>
template <typename Problem, bool Xor = false>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
@@ -250,11 +287,53 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr auto v_lds_block_desc =
make_naive_tensor_descriptor(make_tuple(number<kKPerBlock>{}, number<kNPerBlock>{}),
make_tuple(number<kNPerBlock>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto v_lds_block_desc = [&]() {
if constexpr(Xor)
{
constexpr auto TrLoadFastDimLength = 16;
constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock>{}, number<kNPerBlock>{}),
make_tuple(number<kNPerBlock>{}, number<1>{}),
number<kKPack>{},
number<1>{});
const auto v_lds_block_desc_unmerged = transform_tensor_descriptor(
v_lds_block_desc_naive,
make_tuple(make_pass_through_transform(number<kKPerBlock>{}),
make_unmerge_transform(
make_tuple(number<kNPerBlock / TrLoadFastDimLength>{},
number<TrLoadFastDimLength>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
const auto v_lds_block_desc_permuted = transform_tensor_descriptor(
v_lds_block_desc_unmerged,
make_tuple(
make_xor_transform(make_tuple(number<kKPerBlock>{},
number<kNPerBlock / TrLoadFastDimLength>{})),
make_pass_through_transform(number<TrLoadFastDimLength>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
return transform_tensor_descriptor(
v_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<kKPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<kNPerBlock / TrLoadFastDimLength>{},
number<TrLoadFastDimLength>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock>{}, number<kNPerBlock>{}),
make_tuple(number<kNPerBlock>{}, number<1>{}),
number<kKPack>{},
number<1>{});
}
}();
return v_lds_block_desc;
}