build pass

This commit is contained in:
lalala-sh
2025-07-23 15:38:12 +08:00
parent 7e1bd4b839
commit 3f7d848dd3
3 changed files with 93 additions and 95 deletions

View File

@@ -97,7 +97,7 @@ struct FlatmmScalePointer<-1>
}
};
template <>
template <index_t NumDTensor = 0>
struct BaseFlatmmHostArgs
{
CK_TILE_HOST BaseFlatmmHostArgs() = default;
@@ -169,7 +169,7 @@ struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<>
index_t stride_C_,
ScaleM scale_m_ = nullptr,
ScaleN scale_n_ = nullptr)
: BaseFlatmmHostArgs(a_ptr_, b_shuffle_ptr_, ds_ptr_, c_ptr_, M_, N_, K_, stride_A_, stride_B_, stride_Ds_, stride_C_, k_batch_),
: BaseFlatmmHostArgs(a_ptr_, b_shuffle_ptr_, ds_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_Ds_, stride_C_),
scale_m(scale_m_),
scale_n(scale_n_)
{
@@ -248,7 +248,7 @@ struct FlatmmKernel
template <class ScaleM, class ScaleN>
CK_TILE_HOST static constexpr FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>
MakeKernelArgs(const FlatmmHostArgs<ScaleM, ScaleN, DsDataType::size()>& hostArgs)
MakeKernelArgs(const ScaleFlatmmHostArgs<ScaleM, ScaleN, DsDataType::size()>& hostArgs)
{
return {hostArgs.a_ptr,
hostArgs.b_ptr,
@@ -754,7 +754,7 @@ struct FlatmmKernel
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<scheduler_type>(a_ptr,
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
b_flat_ptr,
kargs.ds_ptr,
e_ptr,