[CK_TILE] FMHA BWD Remove Unnecessary Padding (#2550)

* Remove unnecessary pssk

* Add BlockFmhaBwdDQDKDVPipeline wrapper

* Resolve copilot comments & Remove kpad & fix

* Remove spad
This commit is contained in:
Yi DING
2025-08-07 21:24:43 +08:00
committed by GitHub
parent ffdee5e774
commit b0a97498b0
11 changed files with 158 additions and 189 deletions

View File

@@ -49,8 +49,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
@@ -72,8 +70,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
static constexpr index_t kAlignmentBias = 1;
static constexpr const char* name = "kr_ktr_vr";
@@ -554,7 +551,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
bool need_perpixel_check = mask.IsEdgeTile(
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});

View File

@@ -49,8 +49,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
@@ -72,8 +70,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
static constexpr index_t kAlignmentBias = 1;
static constexpr const char* name = "kr_ktr_vr_iglp";
@@ -590,7 +587,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
bool need_perpixel_check = mask.IsEdgeTile(
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
@@ -849,7 +845,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
bool need_perpixel_check = mask.IsEdgeTile(
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});

View File

@@ -0,0 +1,30 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
namespace ck_tile {
template <typename Problem>
class BlockFmhaBwdDQDKDVPipelineSelector
{
static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV;
public:
using type = std::conditional_t<has_dpad,
BlockFmhaBwdDQDKDVPipelineKRKTRVR<Problem>,
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<Problem>>;
};
// Front-end pipeline type: publicly derives from whichever concrete pipeline
// BlockFmhaBwdDQDKDVPipelineSelector<Problem>::type resolves to.
template <typename Problem>
class BlockFmhaBwdDQDKDVPipeline : public BlockFmhaBwdDQDKDVPipelineSelector<Problem>::type
{
    public:
    // Declared in this class's scope, so it hides the `name` member of the
    // selected base pipeline; the wrapper always reports "auto".
    static constexpr auto name = "auto";
};
} // namespace ck_tile

View File

@@ -1,15 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// Scoped enumeration identifying the FMHA backward pipeline variants.
// This enum is used for codegen pattern matching.
enum class BlockFmhaBwdPipelineEnum
{
// Explicitly pinned to 0.
KRKTRVR_IGLP = 0,
// No explicit value: takes the next value after KRKTRVR_IGLP, i.e. 1.
KRKTRVR,
};
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -55,13 +55,13 @@ struct BlockFmhaBwdPipelineProblem
static constexpr bool kIsDeterministic = kIsDeterministic_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
// Reject traits that set either sequence-length padding flag. Each message names
// the flag its own assertion checks, so a failing build points at the right trait.
static_assert(!Traits::kPadSeqLenQ, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
static_assert(!Traits::kPadSeqLenK, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenK");
};
template <typename ODataType_,