compile error fix

This commit is contained in:
OscarXu
2025-05-15 16:55:20 +08:00
parent 7c76e19e55
commit 68dbe558df
3 changed files with 63 additions and 187 deletions

View File

@@ -115,8 +115,6 @@ struct MulABScaleExpertWeight
}
};
static constexpr bool MulRoutedWeight = true;
using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true
// using CDEElementOp = MulABScale; // combine MulRoutedWeight = true
@@ -163,6 +161,9 @@ using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr bool MulRoutedWeight = false;
#if 0
static constexpr ck::index_t MPerBlock = 128;
@@ -180,8 +181,7 @@ static constexpr ck::index_t EVec = 2;
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;
static constexpr bool MulRoutedWeight = true;
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX
// clang-format off
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
@@ -193,12 +193,11 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
#else
static constexpr ck::index_t MPerBlock = 128;
static constexpr bool MulRoutedWeight = true;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX<
@@ -213,8 +212,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX<
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0,
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
#endif
// clang-format on
int main(int argc, char* argv[])
{
@@ -283,22 +282,9 @@ int main(int argc, char* argv[])
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
max_token_id.mData = {valid_size};
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
int eids[sorted_tile_num]{};
for(int i = 0; i < sorted_tile_num; i++)
{
if(i < valid_tile_num)
{
eids[i] = (i * experts) / valid_tile_num;
}
else
{
eids[i] = 3;
}
}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;
@@ -333,7 +319,7 @@ int main(int argc, char* argv[])
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "a1_t_k: " << a1_t_k_k.mDesc << std::endl;
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
@@ -398,7 +384,7 @@ int main(int argc, char* argv[])
#if 1
preShuffleBuffer(b0_e_n_k.mData.data(),
b0_preshuffled.mData.data(),
N * experts,
N * 2 * experts,
K,
device_op.GetPreShuffleParameters());
#else

View File

@@ -2293,183 +2293,70 @@ struct GridwiseMoeGemmMX
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
// mul scales
const float* p_sorted_weights_0 = p_ds_grid[I0];
const float* p_scale_b = p_ds_grid[I1];
static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
static_assert(M4 == 4);
const index_t m1 = get_warp_local_1d_id() / NWave;
const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
{
if constexpr(PerTokenQuant)
{
constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
}
else
{
p_scale_b += expert_id;
}
vector_type<int32_t, 4> scale_token_ids;
vector_type<float, 4> topk_weights;
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
if constexpr(PerTokenQuant)
vector_type<float, 4> topk_weights; // for gemm2 only
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
if constexpr(MulRoutedWeight)
{
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
p_ds_grid[I2] + m_pos);
}
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, m2 * M4 + m4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
scale_token_ids =
*c_style_pointer_cast<const vector_type<int32_t, M4>*>(
p_sorted_token_ids + m_pos);
}
if constexpr(MulRoutedWeight)
{
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
p_ds_grid[I2] + m_pos);
}
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
float scale_a = [&]() {
if constexpr(PerTokenQuant)
{
index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
const index_t token_offset = fused_token & 0xffffff;
return token_offset < problem.NumTokens
? p_sorted_weights_0[token_offset]
: 0.0;
}
else
{
return p_sorted_weights_0[0];
}
}();
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, m2 * M4 + m4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
float gate = scale_a * scale_b * c_thread_buf[cidx];
float up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
const float scale_up =
p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
PerTokenQuant];
float gate = scale_a * scale_b * c_thread_buf[cidx];
float up = scale_a * scale_up * c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
}
else
{
c_thread_buf_fp32(cidx) =
scale_a * scale_b * c_thread_buf[cidx];
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) *
topk_weights.AsType<float>()[m4];
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
});
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
}
else
{
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) = topk_weights.AsType<float>()[m4] *
c_thread_buf_fp32[cidx];
}
}
});
});
});
}
else
{
vector_type<float, 4> topk_weights; // for gemm2 only
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
if constexpr(MulRoutedWeight)
{
topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
p_ds_grid[I2] + m_pos);
}
static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, m2 * M4 + m4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion
{
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[m4];
up = up * topk_weights.AsType<float>()[m4];
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf_fp32(cidx) = gate * up;
}
}
else
{
c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
if constexpr(MulRoutedWeight)
{
c_thread_buf_fp32(cidx) = topk_weights.AsType<float>()[m4] *
c_thread_buf_fp32[cidx];
}
}
});
});
});
});
}
});
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

View File

@@ -137,12 +137,15 @@ struct ReferenceMoeMXGemm2 : public device::BaseOperator
f4_t f4 = 0;
f4_t f4_up = 0;
if(k % 2 == 1)
{
f4 = (f4x2 >> 0) & 0xf;
f4_up = (f4x2_up >> 0) & 0xf;
}
else
{
f4 = (f4x2 >> 4) & 0xf;
f4_up = (f4x2_upo >> 4) & 0xf;
f4_up = (f4x2_up >> 4) & 0xf;
}
v_b = type_convert<ComputeTypeB>(f4) *
type_convert<ComputeTypeB>(b_scale);
v_b_up = type_convert<ComputeTypeB>(f4_up) *