mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[CK_TILE] FMHA BWD Remove Unnecessary Padding (#2550)
* Remove unnecessary pssk * Add BlockFmhaBwdDQDKDVPipeline wrapper * Resolve copilot comments & Remove kpad & fix * Remove spad
This commit is contained in:
@@ -49,8 +49,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
@@ -72,8 +70,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
static constexpr index_t kAlignmentBias = 1;
|
||||
|
||||
static constexpr const char* name = "kr_ktr_vr";
|
||||
|
||||
@@ -554,7 +551,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
bool need_perpixel_check = mask.IsEdgeTile(
|
||||
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
@@ -49,8 +49,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
@@ -72,8 +70,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
|
||||
static constexpr index_t kAlignmentBias = 1;
|
||||
|
||||
static constexpr const char* name = "kr_ktr_vr_iglp";
|
||||
|
||||
@@ -590,7 +587,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
bool need_perpixel_check = mask.IsEdgeTile(
|
||||
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
@@ -849,7 +845,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
bool need_perpixel_check = mask.IsEdgeTile(
|
||||
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
class BlockFmhaBwdDQDKDVPipelineSelector
|
||||
{
|
||||
static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV;
|
||||
|
||||
public:
|
||||
using type = std::conditional_t<has_dpad,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVR<Problem>,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<Problem>>;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
class BlockFmhaBwdDQDKDVPipeline : public BlockFmhaBwdDQDKDVPipelineSelector<Problem>::type
|
||||
{
|
||||
public:
|
||||
static constexpr const char* name = "auto";
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,15 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This class is used for codegen pattern matching
|
||||
enum class BlockFmhaBwdPipelineEnum
|
||||
{
|
||||
KRKTRVR_IGLP = 0,
|
||||
KRKTRVR,
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -55,13 +55,13 @@ struct BlockFmhaBwdPipelineProblem
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
static_assert(!Traits::kPadSeqLenQ, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
|
||||
static_assert(!Traits::kPadSeqLenK, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
|
||||
};
|
||||
|
||||
template <typename ODataType_,
|
||||
|
||||
Reference in New Issue
Block a user