Huaiguxu/moe fp8 pertoken scale fix (#2391)

* fix pertoken_scale a_scale dimension

* clang-format

* Fix moe_gemm2_fp8 perTokenScale reference and example.

[ROCm/composable_kernel commit: e1c5172fdb]
This commit is contained in:
huaiguxu
2025-06-27 10:24:34 +08:00
committed by GitHub
parent c7c24bb10d
commit 0ac91713ae
3 changed files with 26 additions and 9 deletions

View File

@@ -139,6 +139,7 @@ static constexpr ck::index_t EVec = 2;
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;
+ static constexpr bool PerTokenQuant = true;
static constexpr bool MulRoutedWeight = true;
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off
@@ -169,7 +170,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
- ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>;
+ ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, PerTokenQuant, int32_t, A0DataType>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
@@ -197,7 +198,7 @@ int main(int argc, char* argv[])
{
// use default case
}
- else if(argc == 3)
+ else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
@@ -238,7 +239,8 @@ int main(int argc, char* argv[])
ck::index_t StrideB = K;
ck::index_t StrideE = N;
constexpr ck::index_t NumDTensor = DsDataType::Size();
- constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
+ constexpr auto StrideDs = PerTokenQuant ? std::array<ck::index_t, NumDTensor>{1, 1, 0}
+ : std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
@@ -279,8 +281,10 @@ int main(int argc, char* argv[])
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
- Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
- Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
+ Tensor<D0DataType> d0_t_n(
+ HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0}));
+ Tensor<D1DataType> d1_e_n(
+ HostTensorDescriptor({experts, N}, {PerTokenQuant ? StrideDs[1] * N : 1, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));

View File

@@ -1473,7 +1473,12 @@ struct GridwiseMoeGemm
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
const index_t token_offset = fused_token & 0xffffff;
return token_offset < problem.NumTokens
- ? p_sorted_weights_0[token_offset]
+ ? p_sorted_weights_0[IsInputGemm
+ ? token_offset
+ : token_offset *
+ problem.TopK +
+ (fused_token >>
+ 24)]
: 0.0;
}
else
@@ -2190,7 +2195,12 @@ struct GridwiseMoeGemm
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
const index_t token_offset = fused_token & 0xffffff;
return token_offset < problem.NumTokens
- ? p_sorted_weights_0[token_offset]
+ ? p_sorted_weights_0[IsInputGemm
+ ? token_offset
+ : token_offset *
+ problem.TopK +
+ (fused_token >>
+ 24)]
: 0.0;
}
else

View File

@@ -6,6 +6,7 @@
#include <iostream>
#include <sstream>
#include <unordered_map>
+ #include <mutex>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
@@ -85,6 +86,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
float Run(const Argument& arg)
{
+ std::vector<std::mutex> n_locks(arg.c_t_n_.mDesc.GetLengths()[1]);
arg.c_t_n_.SetZero();
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_t_k_k_.mDesc.GetLengths()[2];
@@ -142,8 +144,8 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
CDataType v_c{0};
- D0DataType v_d0 = arg.d0_(m, n); // a
- D0DataType v_d1 = arg.d1_(e, n); // b
+ D0DataType v_d0 = arg.d0_(t, topk_id); // a
+ D0DataType v_d1 = arg.d1_(e, n); // b
if constexpr(MulRoutedWeight)
{
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w);
@@ -152,6 +154,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
{
arg.c_element_op_(v_c, v_acc, v_d0, v_d1, 1.f);
}
+ std::lock_guard<std::mutex> lock(n_locks[n]);
arg.c_t_n_(t, n) += v_c;
}
};