mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Huaiguxu/moe fp8 pertoken scale fix (#2391)
* fix pertoken_scale a_scale dimension
* clang-format
* Fix moe_gemm2_fp8 perTokenScale reference and example.
[ROCm/composable_kernel commit: e1c5172fdb]
This commit is contained in:
@@ -139,6 +139,7 @@ static constexpr ck::index_t EVec = 2;
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t D2Vec = 1;
|
||||
static constexpr bool PerTokenQuant = true;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// clang-format off
|
||||
@@ -169,7 +170,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, PerTokenQuant, int32_t, A0DataType>;
|
||||
// kernel 2: 128->32x128x128
|
||||
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
|
||||
|
||||
@@ -197,7 +198,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 3)
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
@@ -238,7 +239,8 @@ int main(int argc, char* argv[])
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
constexpr auto StrideDs = PerTokenQuant ? std::array<ck::index_t, NumDTensor>{1, 1, 0}
|
||||
: std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
@@ -279,8 +281,10 @@ int main(int argc, char* argv[])
|
||||
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
|
||||
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
|
||||
Tensor<D0DataType> d0_t_n(
|
||||
HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0}));
|
||||
Tensor<D1DataType> d1_e_n(
|
||||
HostTensorDescriptor({experts, N}, {PerTokenQuant ? StrideDs[1] * N : 1, StrideDs[1]}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
|
||||
@@ -1473,7 +1473,12 @@ struct GridwiseMoeGemm
|
||||
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
|
||||
const index_t token_offset = fused_token & 0xffffff;
|
||||
return token_offset < problem.NumTokens
|
||||
? p_sorted_weights_0[token_offset]
|
||||
? p_sorted_weights_0[IsInputGemm
|
||||
? token_offset
|
||||
: token_offset *
|
||||
problem.TopK +
|
||||
(fused_token >>
|
||||
24)]
|
||||
: 0.0;
|
||||
}
|
||||
else
|
||||
@@ -2190,7 +2195,12 @@ struct GridwiseMoeGemm
|
||||
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
|
||||
const index_t token_offset = fused_token & 0xffffff;
|
||||
return token_offset < problem.NumTokens
|
||||
? p_sorted_weights_0[token_offset]
|
||||
? p_sorted_weights_0[IsInputGemm
|
||||
? token_offset
|
||||
: token_offset *
|
||||
problem.TopK +
|
||||
(fused_token >>
|
||||
24)]
|
||||
: 0.0;
|
||||
}
|
||||
else
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
#include <mutex>
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
@@ -85,6 +86,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
std::vector<std::mutex> n_locks(arg.c_t_n_.mDesc.GetLengths()[1]);
|
||||
arg.c_t_n_.SetZero();
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = arg.a_t_k_k_.mDesc.GetLengths()[2];
|
||||
@@ -142,8 +144,8 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
CDataType v_c{0};
|
||||
D0DataType v_d0 = arg.d0_(m, n); // a
|
||||
D0DataType v_d1 = arg.d1_(e, n); // b
|
||||
D0DataType v_d0 = arg.d0_(t, topk_id); // a
|
||||
D0DataType v_d1 = arg.d1_(e, n); // b
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w);
|
||||
@@ -152,6 +154,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
|
||||
{
|
||||
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, 1.f);
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(n_locks[n]);
|
||||
arg.c_t_n_(t, n) += v_c;
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user