[CK_TILE] FMHA BWD Fix Compilation with Bias (#2682)

* [CK_TILE] FMHA BWD Fix Compilation with Bias

* Fix appendkv kApplyRoPE
This commit is contained in:
Yi DING
2025-08-22 10:01:10 +08:00
committed by GitHub
parent 49c6b05c72
commit 4cfa2c7158
5 changed files with 28 additions and 65 deletions

View File

@@ -649,8 +649,12 @@ struct FmhaFwdAppendKVKernel
{0, i_n0});
// If kApplyRoPe is false, we set the rotary_dim to 0
auto rotary_dim = kApplyRoPE ? kargs.rotary_dim : 0;
auto rotary_dim = [&]() {
if constexpr(kApplyRoPE)
return kargs.rotary_dim;
else
return 0;
}();
FmhaPipeline{}(q_dram_window,
k_dram_window,
i_page_block_k,

View File

@@ -347,22 +347,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
bias_dram_block_window_tmp.get_bottom_tensor_view()),
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, bias_origin.at(number<1>{})},
Policy::template MakeBiasTileDistribution<Problem>());
auto bias_lds = make_tensor_view<address_space_enum::lds>(
bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor<Problem>());
bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
auto bias_lds_write_window =
make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto bias_lds_read = make_tensor_view<address_space_enum::lds>(
bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor<Problem>());
auto bias_s_lds_read_window =
make_tile_window(bias_lds_read,
make_tuple(number<kM0>{}, number<kN0>{}),
make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
bias_lds_write_window.get_window_lengths(),
bias_lds_write_window.get_window_origin(),
Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
@@ -500,8 +497,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
async_load_tile(bias_lds_write_window, bias_dram_window);
__builtin_amdgcn_s_waitcnt(3952);
const auto bias_tile = load_tile(bias_dram_window);
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(shuffled_bias_tile, bias_tile);
store_tile(bias_lds_write_window, shuffled_bias_tile);
block_sync_lds();
auto bias_s_tile = load_tile(bias_s_lds_read_window);
tile_elementwise_inout(

View File

@@ -323,22 +323,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
bias_dram_block_window_tmp.get_bottom_tensor_view()),
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_kv_start},
Policy::template MakeBiasTileDistribution<Problem>());
auto bias_lds = make_tensor_view<address_space_enum::lds>(
bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor<Problem>());
bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
auto bias_lds_write_window =
make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto bias_lds_read = make_tensor_view<address_space_enum::lds>(
bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor<Problem>());
auto bias_s_lds_read_window =
make_tile_window(bias_lds_read,
make_tuple(number<kM0>{}, number<kN0>{}),
make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
bias_lds_write_window.get_window_lengths(),
bias_lds_write_window.get_window_origin(),
Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
@@ -490,8 +487,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
async_load_tile(bias_lds_write_window, bias_dram_window);
__builtin_amdgcn_s_waitcnt(3952);
const auto bias_tile = load_tile(bias_dram_window);
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(shuffled_bias_tile, bias_tile);
store_tile(bias_lds_write_window, shuffled_bias_tile);
block_sync_lds();
auto bias_s_tile = load_tile(bias_s_lds_read_window);
tile_elementwise_inout(

View File

@@ -551,11 +551,9 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
Problem::BlockFmhaShape::kQKHeaddim>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsWriteBlockDescriptor()
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsBlockDescriptor()
{
return MakeXLdsWriteBlockDescriptor<typename Problem::BiasDataType,
Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0>();
return BlockFmhaBwdPipelineDefaultPolicy::MakeBiasLdsBlockDescriptor<Problem>();
}
template <typename Problem, bool Transposed = false>
@@ -684,13 +682,6 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsReadBlockDescriptor()
{
return MakeXLdsReadBlockDescriptor<typename Problem::BiasDataType,
Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSliceBlockDescriptor()
@@ -966,25 +957,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t N1 = min(static_cast<index_t>(GetAlignmentBias<Problem>()),
kMPerBlock * kNPerBlock / kBlockSize);
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M1 = get_warp_size() / N0;
constexpr index_t M2 = kMPerBlock / M1 / M0;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>>{});
return BlockFmhaBwdPipelineDefaultPolicy::MakeShuffledBiasTileDistribution<Problem>();
}
template <typename BlockGemm>
@@ -1048,7 +1021,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
{
if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return sizeof(typename Problem::BiasDataType) *
MakeBiasLdsWriteBlockDescriptor<Problem>().get_element_space_size();
MakeBiasLdsBlockDescriptor<Problem>().get_element_space_size();
else
return 0;
}