mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
support bf16 cshuffle
This commit is contained in:
@@ -36,7 +36,7 @@ using A0DataType = F8;
|
||||
using B0DataType = F8;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
@@ -67,6 +67,15 @@ struct MulABScaleExpertWeight
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, EDataType, float, float, float>(
|
||||
EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
|
||||
@@ -408,7 +417,7 @@ int main(int argc, char* argv[])
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<CShuffleDataType> c_t_n({tokens, N});
|
||||
Tensor<float> c_t_n({tokens, N});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
|
||||
@@ -416,7 +425,7 @@ int main(int argc, char* argv[])
|
||||
D0DataType,
|
||||
D1DataType,
|
||||
D2DataType,
|
||||
CShuffleDataType,
|
||||
float,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
|
||||
Reference in New Issue
Block a user