mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
fix tail handler bug
This commit is contained in:
@@ -25,9 +25,19 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_num)
|
||||
{
|
||||
if (TailNumber::Even == tail_num)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
else if (TailNumber::Odd == tail_num)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
// assert(false);
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
// return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -178,11 +178,13 @@ using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
|
||||
VectorSizeA_,
|
||||
VectorSizeB_>;
|
||||
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
@@ -202,7 +204,7 @@ struct FlatmmPipelineProblem
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
@@ -318,7 +320,7 @@ struct FlatmmPipelineProblem
|
||||
return kPadM ? 1 : GetAlignmentC();
|
||||
}
|
||||
}();
|
||||
};
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
|
||||
Reference in New Issue
Block a user