mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
bias support
This commit is contained in:
@@ -384,7 +384,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>()));
|
||||
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
|
||||
Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeOGradT<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
|
||||
Policy::template GetSmemSizeD<Problem>()));
|
||||
|
||||
auto biast_lds = make_tensor_view<address_space_enum::lds>(
|
||||
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
|
||||
@@ -555,9 +559,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
block_sync_lds();
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(bias_shuffle_tmp, bias_tile);
|
||||
@@ -571,6 +573,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
st_acc,
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
@@ -725,6 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
|
||||
store_tile(dbias_dram_window, dbiast_shuffle_tmp);
|
||||
move_tile_window(dbias_dram_window, {kM0, 0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
@@ -807,9 +811,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
block_sync_lds();
|
||||
const auto bias_tile = load_tile(bias_dram_window);
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
shuffle_tile(bias_shuffle_tmp, bias_tile);
|
||||
|
||||
@@ -331,21 +331,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentBias()
|
||||
{
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t kTotalPixels = kMPerBlock * kNPerBlock / kBlockSize;
|
||||
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(BiasDataType);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(BiasDataType);
|
||||
|
||||
constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad)
|
||||
? kMaxVecLoad
|
||||
: (kTotalPixels / kMinVecLoad);
|
||||
return kVecLoad;
|
||||
// TODO: not correct!
|
||||
if constexpr(kTotalPixels > 32)
|
||||
return 8;
|
||||
else
|
||||
return 4;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -617,7 +613,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias()
|
||||
{
|
||||
return GetAlignmentBias<Problem>();
|
||||
// TODO: this is for 3d layout
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
return 16 / sizeof(BiasDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1682,7 +1680,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t smem_size_stage0_1 = smem_size_v;
|
||||
constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot +
|
||||
smem_size_do + smem_size_lse + smem_size_d +
|
||||
smem_size_ds;
|
||||
max(smem_size_bias, smem_size_ds);
|
||||
constexpr index_t smem_size_stage2 = smem_size_qt + smem_size_bias;
|
||||
constexpr index_t smem_size_stage3 = smem_size_qt;
|
||||
constexpr index_t smem_size_stage4 = smem_size_qt + smem_size_do + smem_size_d;
|
||||
|
||||
Reference in New Issue
Block a user