support bf16 cshuffle

This commit is contained in:
coderfeli
2025-03-26 06:50:21 +00:00
parent 6a0cc4aad1
commit cd12af75af

View File

@@ -36,7 +36,7 @@ using A0DataType = F8;
using B0DataType = F8;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CShuffleDataType = F16;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = F32;
@@ -67,6 +67,15 @@ struct MulABScaleExpertWeight
(void)d2;
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float, float>(
EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
@@ -408,7 +417,7 @@ int main(int argc, char* argv[])
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<CShuffleDataType> c_t_n({tokens, N});
Tensor<float> c_t_n({tokens, N});
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
@@ -416,7 +425,7 @@ int main(int argc, char* argv[])
D0DataType,
D1DataType,
D2DataType,
CShuffleDataType,
float,
AccDataType,
PassThrough,
PassThrough,