mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
qsksvs pipeline changes to mirror qrksvs
This commit is contained in:
@@ -95,6 +95,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
{
|
||||
constexpr std::array occupancy{2, 2, 2, 2, 2, 1};
|
||||
return occupancy[detail::log2<kMaxSplits>::value - 2];
|
||||
} else if constexpr(kHeadDimV <= 512) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -96,6 +96,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 512)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ namespace ck_tile {
|
||||
/// NOTICE: we no-longer use this pipeline.
|
||||
// This pipeline is qkv all located in LDS
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
|
||||
struct [[deprecated]] BlockFmhaPipelineQSKSVS
|
||||
struct BlockFmhaPipelineQSKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
@@ -51,6 +51,24 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
@@ -81,6 +99,9 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
|
||||
|
||||
static constexpr const char* name = "qs";
|
||||
|
||||
// using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
|
||||
using DropoutType = int32_t; // unused
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
@@ -95,6 +116,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
@@ -106,6 +128,23 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
// operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
// const QElementFunction& q_element_func,
|
||||
// const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
// const KElementFunction& k_element_func,
|
||||
// const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
// const VElementFunction& v_element_func,
|
||||
// const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
// const BiasElementFunction& bias_element_func,
|
||||
// LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
// const LSEElementFunction& lse_element_func,
|
||||
// const SAccElementFunction& s_acc_element_func,
|
||||
// const PComputeElementFunction& p_compute_element_func,
|
||||
// const OAccElementFunction& o_acc_element_func,
|
||||
// FmhaMask mask,
|
||||
// PositionEncoding position_encoding,
|
||||
// float scale_s,
|
||||
// void* smem_ptr) const
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
@@ -114,6 +153,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
@@ -122,7 +162,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
|
||||
@@ -127,7 +127,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
|
||||
/// NOTICE: we no-longer use this policy.
|
||||
template <>
|
||||
struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
{
|
||||
static constexpr bool QLoadOnce = false;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user