Fix merge bug: add DeviceMoEGemmMXBPreShuffle again (#2816)

This commit is contained in:
Enrico Degregori
2025-09-10 17:03:29 +02:00
committed by GitHub
parent 7ecdba878f
commit bbc8c7d999

View File

@@ -149,6 +149,52 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
#endif
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename DsDataType,
typename EDataType,
index_t ScaleBlockSize,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceMoEGemmMXBPreShuffle : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
#ifndef CK_CODE_GEN_RTC
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideAScale,
ck::index_t StrideB,
ck::index_t StrideBScale,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual int GetPreShuffleParameters() = 0;
#endif
};
/// @brief Wrapper for backward compatibility that allows to use instances of
/// DeviceGemmMultipleDSplitK in contexts where DeviceGemmMultipleD is expected.
///