mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
build pass
This commit is contained in:
@@ -97,7 +97,7 @@ struct FlatmmScalePointer<-1>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
template <index_t NumDTensor = 0>
|
||||
struct BaseFlatmmHostArgs
|
||||
{
|
||||
CK_TILE_HOST BaseFlatmmHostArgs() = default;
|
||||
@@ -169,7 +169,7 @@ struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<>
|
||||
index_t stride_C_,
|
||||
ScaleM scale_m_ = nullptr,
|
||||
ScaleN scale_n_ = nullptr)
|
||||
: BaseFlatmmHostArgs(a_ptr_, b_shuffle_ptr_, ds_ptr_, c_ptr_, M_, N_, K_, stride_A_, stride_B_, stride_Ds_, stride_C_, k_batch_),
|
||||
: BaseFlatmmHostArgs(a_ptr_, b_shuffle_ptr_, ds_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_Ds_, stride_C_),
|
||||
scale_m(scale_m_),
|
||||
scale_n(scale_n_)
|
||||
{
|
||||
@@ -248,7 +248,7 @@ struct FlatmmKernel
|
||||
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_HOST static constexpr FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>
|
||||
MakeKernelArgs(const FlatmmHostArgs<ScaleM, ScaleN, DsDataType::size()>& hostArgs)
|
||||
MakeKernelArgs(const ScaleFlatmmHostArgs<ScaleM, ScaleN, DsDataType::size()>& hostArgs)
|
||||
{
|
||||
return {hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
@@ -754,7 +754,7 @@ struct FlatmmKernel
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
|
||||
RunFlatmm<scheduler_type>(a_ptr,
|
||||
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
|
||||
b_flat_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
|
||||
Reference in New Issue
Block a user