mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
fix warnings and impl scale for gemm2, build ok
This commit is contained in:
@@ -152,6 +152,9 @@ static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this const
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t D2Vec = 1;
|
||||
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// clang-format off
|
||||
@@ -181,7 +184,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
1, 1, S<1, 32, 1, 8>, S<EVec, EVec, 1, EVec>,
|
||||
1, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>;
|
||||
// kernel 2: 128->32x128x128
|
||||
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
|
||||
|
||||
@@ -40,9 +40,8 @@ using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = EDataType;
|
||||
// using DsDataTypeGate = ck::Tuple<D0DataType, D1DataType>;
|
||||
using DsDataTypeUp = ck::Tuple<D0DataType, D1DataType, D2DataType>;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
@@ -51,35 +50,39 @@ using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
// using DsLayoutGate = ck::Tuple<D0Layout, D1Layout>;
|
||||
using DsLayoutUp = ck::Tuple<D0Layout, D1Layout, D2Layout>;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
|
||||
|
||||
struct MultiplyMultiply
|
||||
// d0: ascale, d1: bscale, d2:expert weight
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
|
||||
//gpu
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float, D2DataType>
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>
|
||||
(EDataType& e,
|
||||
const float& c,
|
||||
const float& d0,
|
||||
const float& d1,
|
||||
const D2DataType& d2) const
|
||||
const float& d2) const
|
||||
{
|
||||
// const float x0_f = c * d0 * d1;
|
||||
(void)d0; (void)d1; (void)d2;
|
||||
const float x0_f = c;
|
||||
e = ck::type_convert<EDataType>(x0_f);
|
||||
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
|
||||
}
|
||||
// for reference
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float, float, float, float>
|
||||
(float& e,
|
||||
const float& c,
|
||||
const float& d0,
|
||||
const float& d1,
|
||||
const float& d2) const
|
||||
{
|
||||
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
// using DsLayout = DsLayoutGate;
|
||||
// using DsDataType = DsDataTypeGate;
|
||||
using DsLayout = DsLayoutUp;
|
||||
using DsDataType = DsDataTypeUp;
|
||||
using CDEElementOp = MultiplyMultiply;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
|
||||
{
|
||||
@@ -115,7 +118,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MultiplyMultiply;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
@@ -126,6 +129,9 @@ static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t D2Vec = 1;
|
||||
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// clang-format off
|
||||
@@ -155,7 +161,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S<EVec, EVec, 1, EVec>,
|
||||
CShuffleMXDLPerWave, 1, S<1, 16, 1, 16>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, A0DataType>;
|
||||
// kernel 2: 128->32x128x128
|
||||
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
|
||||
@@ -232,16 +238,16 @@ int main(int argc, char* argv[])
|
||||
Tensor<A0DataType> a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
|
||||
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({N, 1}, {1, 0}));
|
||||
Tensor<D1DataType> d1_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
|
||||
Tensor<D2DataType> d2_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
|
||||
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({SORTED_SIZE, N}, {0, 0}));
|
||||
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {0, 0}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({experts, 1}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
e_t_n_device_result.SetZero();
|
||||
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d2_m_n: " << d2_m_n.mDesc << std::endl;
|
||||
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
|
||||
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
|
||||
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
|
||||
|
||||
@@ -252,38 +258,38 @@ int main(int argc, char* argv[])
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
|
||||
d2_m_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
|
||||
d2_m_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_m_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
|
||||
a0_m_k.savetxt("a.txt");
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
a0_device_buf.ToDevice(a0_m_k.mData.data());
|
||||
d0_device_buf.ToDevice(d0_t_n.mData.data());
|
||||
d1_device_buf.ToDevice(d1_m_n.mData.data());
|
||||
d2_device_buf.ToDevice(d2_m_n.mData.data());
|
||||
d1_device_buf.ToDevice(d1_e_n.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
@@ -358,26 +364,26 @@ int main(int argc, char* argv[])
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
|
||||
B0DataType,
|
||||
D0DataType,
|
||||
D1DataType,
|
||||
D2DataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
CDEElementOp>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(
|
||||
sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, c_t_n, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, d0_t_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
|
||||
// const int t = sorted_token_ids(m);
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
cde_element_op(e_t_n_host_result(t, n), c_t_n(t, n), d0_t_n(t, n), d1_m_n(t, n), d2_m_n(t, n));
|
||||
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -535,7 +535,7 @@ struct DeviceMoeGemm
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
// assert(0, "no impl");
|
||||
return std::make_unique<Argument>(nullptr, nullptr,
|
||||
|
||||
@@ -901,9 +901,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
StaticallyIndexedArray<index_t, gather_num> gather_offsets_;
|
||||
const SrcElementwiseOperation src_element_op_;
|
||||
const DstElementwiseOperation dst_element_op_;
|
||||
StaticallyIndexedArray<index_t, gather_num> gather_offsets_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -687,10 +687,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
using OOBVectorTuple = StaticallyIndexedArray<bool, src_num_access>;
|
||||
StaticallyIndexedArray<OOBVectorTuple, NumThreadScratch> oob_vectors_tuple_;
|
||||
|
||||
StaticallyIndexedArray<index_t, scatter_num> scatter_offsets_;
|
||||
SrcCoords src_coords_;
|
||||
DstCoords dst_coords_;
|
||||
const ElementwiseOperation element_op_;
|
||||
StaticallyIndexedArray<index_t, scatter_num> scatter_offsets_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -17,6 +17,9 @@ namespace host {
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename D0DataType,
|
||||
typename D1DataType,
|
||||
typename D2DataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
@@ -33,6 +36,9 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
const index_t sorted_tile_size,
|
||||
const Tensor<ADataType>& a_m_k,
|
||||
const Tensor<BDataType>& b_e_n_k,
|
||||
const Tensor<D0DataType>& d0,
|
||||
const Tensor<D1DataType>& d1,
|
||||
const Tensor<D2DataType>& d2,
|
||||
Tensor<CDataType>& c_t_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
@@ -42,6 +48,9 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
sorted_tile_size_{sorted_tile_size},
|
||||
a_m_k_{a_m_k},
|
||||
b_e_n_k_{b_e_n_k},
|
||||
d0_{d0},
|
||||
d1_{d1},
|
||||
d2_{d2},
|
||||
c_t_n_{c_t_n},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
@@ -49,16 +58,19 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ck::index_t>& expert_ids_;
|
||||
const Tensor<ck::index_t>& sorted_token_ids_;
|
||||
const Tensor<ck::index_t>& expert_ids_;
|
||||
index_t sorted_tile_size_;
|
||||
const Tensor<ADataType>& a_m_k_;
|
||||
const Tensor<BDataType>& b_e_n_k_;
|
||||
const Tensor<D0DataType>& d0_;
|
||||
const Tensor<D1DataType>& d1_;
|
||||
const Tensor<D2DataType>& d2_;
|
||||
Tensor<CDataType>& c_t_n_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
index_t sorted_tile_size_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -106,8 +118,10 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
CDataType v_c{0};
|
||||
|
||||
arg.c_element_op_(v_c, v_acc);
|
||||
D0DataType v_d0 = arg.d0_(m, n); // a
|
||||
D0DataType v_d1 = arg.d1_(e, n); // b
|
||||
D0DataType v_d2 = arg.d2_(e, 0); //expert
|
||||
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_d2);
|
||||
|
||||
arg.c_t_n_(t, n) += v_c;
|
||||
}
|
||||
@@ -140,12 +154,15 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
const index_t sorted_tile_size,
|
||||
const Tensor<ADataType>& a_m_k,
|
||||
const Tensor<BDataType>& b_e_n_k,
|
||||
const Tensor<D0DataType>& d0,
|
||||
const Tensor<D1DataType>& d1,
|
||||
const Tensor<D2DataType>& d2,
|
||||
Tensor<CDataType>& c_t_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_m_k, b_e_n_k, c_t_n, a_element_op, b_element_op, c_element_op};
|
||||
return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_m_k, b_e_n_k, d0, d1, d2, c_t_n, a_element_op, b_element_op, c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
Reference in New Issue
Block a user