mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
now fwd/bwd can build
This commit is contained in:
@@ -7,8 +7,8 @@
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
@@ -38,7 +38,6 @@
|
||||
#include "ck_tile/core/tensor/slice_tile.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/store_tile.hpp"
|
||||
#include "ck_tile/core/tensor/update_tile.hpp"
|
||||
#include "ck_tile/core/tensor/sweep_tile.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
|
||||
@@ -49,13 +48,14 @@
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/update_tile.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
#include "ck_tile/core/utility/magic_div.hpp"
|
||||
#include "ck_tile/core/utility/philox_rand.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/unary_element_function.hpp"
|
||||
#include "ck_tile/core/utility/philox_rand.hpp"
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/meta_data_buffer.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
@@ -11,11 +11,11 @@
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/ranges.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_masking.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_im2col.hpp"
|
||||
#include "ck_tile/host/reference/reference_reduce.hpp"
|
||||
|
||||
@@ -4,11 +4,24 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
|
||||
@@ -19,19 +32,6 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
@@ -56,7 +57,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = FmhaPipeline::kHasBias;
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
@@ -91,7 +92,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
_TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" +
|
||||
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
|
||||
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
|
||||
("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + (kHasBias ? "_bias" : "") +
|
||||
("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) +
|
||||
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
@@ -212,7 +214,9 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
struct FmhaBwdBatchModeKargs
|
||||
: FmhaBwdCommonKargs,
|
||||
std::conditional_t<kHasBias, FmhaBwdBatchModeBiasKargs, FmhaBwdEmptyKargs<0>>,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
|
||||
FmhaBwdBatchModeBiasKargs,
|
||||
FmhaBwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasBiasGrad, FmhaBwdBatchModeBiasGradKargs, FmhaBwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
|
||||
std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>
|
||||
@@ -227,7 +231,9 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
struct FmhaBwdGroupModeKargs
|
||||
: FmhaBwdCommonKargs,
|
||||
std::conditional_t<kHasBias, FmhaBwdCommonBiasKargs, FmhaBwdEmptyKargs<0>>,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
|
||||
FmhaBwdCommonBiasKargs,
|
||||
FmhaBwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
|
||||
std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>
|
||||
@@ -336,7 +342,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
batch_stride_dk,
|
||||
batch_stride_dv};
|
||||
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
kargs.bias_ptr = bias_ptr;
|
||||
kargs.stride_bias = stride_bias;
|
||||
@@ -458,7 +464,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
kargs.bias_ptr = bias_ptr;
|
||||
kargs.stride_bias = stride_bias;
|
||||
@@ -537,7 +543,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
|
||||
batch_offset_dk = key_start * kargs.stride_dk;
|
||||
batch_offset_dv = key_start * kargs.stride_dv;
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start * kargs.stride_bias;
|
||||
}
|
||||
@@ -587,7 +593,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
|
||||
batch_offset_dk = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
|
||||
batch_offset_dv = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
|
||||
}
|
||||
@@ -919,7 +925,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
constexpr auto bias_dram_window_lengths =
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
|
||||
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
const BiasDataType* bias_ptr =
|
||||
reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
|
||||
|
||||
@@ -34,8 +34,8 @@ struct FmhaFwdKernel
|
||||
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
|
||||
using RandValOutputDataType =
|
||||
ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
|
||||
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
|
||||
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
|
||||
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
|
||||
|
||||
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
|
||||
|
||||
@@ -58,7 +58,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
@@ -433,13 +433,13 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
q_block_tile = load_tile(q_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
@@ -484,7 +484,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
@@ -532,7 +532,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
@@ -554,7 +555,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
@@ -406,13 +406,13 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
q_block_tile = load_tile(q_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
@@ -457,7 +457,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
@@ -505,7 +505,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
@@ -527,7 +528,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
@@ -372,13 +372,13 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
clear_tile(st_acc); // Initialize S^T
|
||||
store_tile(q_lds_window, q_block_tile); // LDS write
|
||||
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
@@ -413,7 +413,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
@@ -461,7 +461,8 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
@@ -483,7 +484,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
|
||||
@@ -644,7 +644,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeBias()
|
||||
{
|
||||
constexpr index_t smem_size_bias = [&]() {
|
||||
if constexpr(Problem::kHasBias)
|
||||
if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return sizeof(typename Problem::BiasDataType) *
|
||||
MakeBiasTLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
else
|
||||
|
||||
@@ -55,7 +55,7 @@ struct BlockFmhaBwdPipelineProblem
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Traits::kHasBias;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
@@ -12,9 +11,13 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
|
||||
|
||||
Reference in New Issue
Block a user