fix tail handler bug

This commit is contained in:
lalala-sh
2025-07-17 08:40:35 +00:00
parent fb76450e63
commit 9aa3396a79
4 changed files with 19 additions and 4 deletions

View File

@@ -25,9 +25,19 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
/// Turns the runtime `tail_num` into a compile-time tag and forwards it to `run_func`.
///
/// `run_func` is called as
///   run_func(bool_constant<true>{}, integral_constant<TailNumber, X>{})
/// where X is `Even` or `Odd` when `tail_num` equals that value, and `Empty`
/// for every other value.
///
/// The unnamed `bool` argument is never read; the first tag handed to
/// `run_func` is always `bool_constant<true>`.
///
/// The return type is deduced (`auto`), so `run_func` has to return the same
/// type for all three tag combinations or deduction fails to compile.
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_num)
{
    if (TailNumber::Even == tail_num)
    {
        return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Even>{});
    }
    else if (TailNumber::Odd == tail_num)
    {
        return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Odd>{});
    }

    // Fallback for any value other than Even/Odd.
    // NOTE(review): a disabled `assert(false)` used to sit here, so other
    // values are presumably not expected to reach this point - confirm.
    return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
}
};

View File

@@ -178,11 +178,13 @@ using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
VectorSizeA_,
VectorSizeB_>;
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_>
@@ -202,7 +204,7 @@ struct FlatmmPipelineProblem
using CLayout = remove_cvref_t<typename Traits::CLayout>;
static constexpr bool TransposeC = Traits::TransposeC;
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
@@ -318,7 +320,7 @@ struct FlatmmPipelineProblem
return kPadM ? 1 : GetAlignmentC();
}
}();
};
};
template <typename ADataType_,
typename BDataType_,