mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] FMHA BWD Fix Compilation with Bias (#2682)
* [CK_TILE] FMHA BWD Fix Compilation with Bias * Fix appendkv kApplyRoPE
This commit is contained in:
@@ -649,8 +649,12 @@ struct FmhaFwdAppendKVKernel
|
||||
{0, i_n0});
|
||||
|
||||
// If kApplyRoPE is false, we set the rotary_dim to 0
|
||||
auto rotary_dim = kApplyRoPE ? kargs.rotary_dim : 0;
|
||||
|
||||
auto rotary_dim = [&]() {
|
||||
if constexpr(kApplyRoPE)
|
||||
return kargs.rotary_dim;
|
||||
else
|
||||
return 0;
|
||||
}();
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
i_page_block_k,
|
||||
|
||||
@@ -347,22 +347,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
|
||||
bias_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, bias_origin.at(number<1>{})},
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto bias_lds = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor<Problem>());
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
|
||||
auto bias_lds_write_window =
|
||||
make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
auto bias_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor<Problem>());
|
||||
auto bias_s_lds_read_window =
|
||||
make_tile_window(bias_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
|
||||
bias_lds_write_window.get_window_lengths(),
|
||||
bias_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
@@ -500,8 +497,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
async_load_tile(bias_lds_write_window, bias_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(shuffled_bias_tile, bias_tile);
|
||||
store_tile(bias_lds_write_window, shuffled_bias_tile);
|
||||
block_sync_lds();
|
||||
auto bias_s_tile = load_tile(bias_s_lds_read_window);
|
||||
tile_elementwise_inout(
|
||||
|
||||
@@ -323,22 +323,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
|
||||
bias_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_kv_start},
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto bias_lds = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor<Problem>());
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
|
||||
auto bias_lds_write_window =
|
||||
make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
auto bias_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor<Problem>());
|
||||
auto bias_s_lds_read_window =
|
||||
make_tile_window(bias_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
|
||||
bias_lds_write_window.get_window_lengths(),
|
||||
bias_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
@@ -490,8 +487,11 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
async_load_tile(bias_lds_write_window, bias_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(shuffled_bias_tile, bias_tile);
|
||||
store_tile(bias_lds_write_window, shuffled_bias_tile);
|
||||
block_sync_lds();
|
||||
auto bias_s_tile = load_tile(bias_s_lds_read_window);
|
||||
tile_elementwise_inout(
|
||||
|
||||
@@ -551,11 +551,9 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
Problem::BlockFmhaShape::kQKHeaddim>();
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsWriteBlockDescriptor()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsBlockDescriptor()
|
||||
{
|
||||
return MakeXLdsWriteBlockDescriptor<typename Problem::BiasDataType,
|
||||
Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0>();
|
||||
return BlockFmhaBwdPipelineDefaultPolicy::MakeBiasLdsBlockDescriptor<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem, bool Transposed = false>
|
||||
@@ -684,13 +682,6 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim>();
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsReadBlockDescriptor()
|
||||
{
|
||||
return MakeXLdsReadBlockDescriptor<typename Problem::BiasDataType,
|
||||
Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSliceBlockDescriptor()
|
||||
@@ -966,25 +957,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t N1 = min(static_cast<index_t>(GetAlignmentBias<Problem>()),
|
||||
kMPerBlock * kNPerBlock / kBlockSize);
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t M0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M1 = get_warp_size() / N0;
|
||||
constexpr index_t M2 = kMPerBlock / M1 / M0;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>>{});
|
||||
return BlockFmhaBwdPipelineDefaultPolicy::MakeShuffledBiasTileDistribution<Problem>();
|
||||
}
|
||||
|
||||
template <typename BlockGemm>
|
||||
@@ -1048,7 +1021,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
{
|
||||
if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return sizeof(typename Problem::BiasDataType) *
|
||||
MakeBiasLdsWriteBlockDescriptor<Problem>().get_element_space_size();
|
||||
MakeBiasLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user