mxfp4 moe gemm1 function passed

This commit is contained in:
mtgu0705
2025-06-18 09:44:43 -05:00
parent 0743f88625
commit 7516bcd792
3 changed files with 80 additions and 110 deletions

View File

@@ -86,38 +86,6 @@ struct MulABScaleExpertWeight
using CDEElementOp = MulABScaleExpertWeight;
// B preshuffle
void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl)
{
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
I64 tempk;
for(I64 n = 0; n < N; ++n)
{
for(I64 k = 0; k < K_pk; ++k)
{
I64 n0 = n / NLane;
I64 n1 = n % NLane;
I64 k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
I64 k1 = tempk / KPack;
I64 k2 = tempk % KPack;
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K_pk + k];
}
}
}
// A, B Scale preshuffle
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
@@ -214,8 +182,8 @@ int main(int argc, char* argv[])
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 4096;
ck::index_t K = 6144;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
@@ -224,7 +192,7 @@ int main(int argc, char* argv[])
{
// use default case
}
else if(argc == 3)
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
@@ -344,11 +312,6 @@ int main(int argc, char* argv[])
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
// a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
// b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
// a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
// b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
@@ -448,23 +411,19 @@ int main(int argc, char* argv[])
printf("a0_t_k_k:\n");
for(int t = 0; t < tokens; ++t)
{
//for(int tk = 0; tk < topk; ++tk)
for(int k = 0; k < K; ++k)
{
for(int k = 0; k < K; ++k)
auto f4x2 = a0_t_k(t, k).data;
if(k % 2 == 0)
{
auto f4x2 = a0_t_k(t, k).data;
if(k % 2 == 0)
{
ck::f4_t f4 = (f4x2 >> 4) & 0xf;
printf("%.2f ", ck::type_convert<float>(f4));
}
else
{
ck::f4_t f4 = (f4x2 >> 0) & 0xf;
printf("%.2f ", ck::type_convert<float>(f4));
}
ck::f4_t f4 = (f4x2 >> 4) & 0xf;
printf("%.2f ", ck::type_convert<float>(f4));
}
else
{
ck::f4_t f4 = (f4x2 >> 0) & 0xf;
printf("%.2f ", ck::type_convert<float>(f4));
}
printf("\n");
}
printf("\n");
}
@@ -472,13 +431,9 @@ int main(int argc, char* argv[])
printf("a1_t_k_k:\n");
for(int t = 0; t < tokens; ++t)
{
for(int tk = 0; tk < topk; ++tk)
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k)
{
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k)
{
printf("%.2f ", ck::type_convert<float>(a1_t_k_k(t, tk, k)));
}
printf("\n");
printf("%.2f ", ck::type_convert<float>(a1_t_k(t, k)));
}
printf("\n");
}

View File

@@ -207,15 +207,16 @@ int main(int argc, char* argv[])
// per expert:
// GEMM shape
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 13;
ck::index_t valid_tile_num = 13;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t tokens = 832;
ck::index_t topk = 2;
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
if(argc == 1)
{
@@ -263,8 +264,8 @@ int main(int argc, char* argv[])
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
max_token_id.mData = {valid_size};
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({sorted_tile_num + 1}));
max_token_id.mData[0] = valid_size;
if(tokens * topk > valid_size)
{
@@ -340,29 +341,28 @@ int main(int argc, char* argv[])
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
case 4:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
case 5:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{1});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
case 6:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
@@ -433,23 +433,19 @@ int main(int argc, char* argv[])
printf("a0_t_k_k:\n");
for(int t = 0; t < tokens; ++t)
{
//for(int tk = 0; tk < topk; ++tk)
for(int k = 0; k < K; ++k)
{
for(int k = 0; k < K; ++k)
auto f4x2 = a0_t_k(t, k).data;
if(k % 2 == 0)
{
auto f4x2 = a0_t_k(t, k).data;
if(k % 2 == 0)
{
ck::f4_t f4 = (f4x2 >> 4) & 0xf;
printf("%.2f ", ck::type_convert<float>(f4));
}
else
{
ck::f4_t f4 = (f4x2 >> 0) & 0xf;
printf("%.2f ", ck::type_convert<float>(f4));
}
ck::f4_t f4 = (f4x2 >> 4) & 0xf;
printf("%.2f ", ck::type_convert<float>(f4));
}
else
{
ck::f4_t f4 = (f4x2 >> 0) & 0xf;
printf("%.2f ", ck::type_convert<float>(f4));
}
printf("\n");
}
printf("\n");
}
@@ -457,13 +453,9 @@ int main(int argc, char* argv[])
printf("a1_t_k_k:\n");
for(int t = 0; t < tokens; ++t)
{
for(int tk = 0; tk < topk; ++tk)
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k)
{
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k)
{
printf("%.2f ", ck::type_convert<float>(a1_t_k_k(t, tk, k)));
}
printf("\n");
printf("%.2f ", ck::type_convert<float>(a1_t_k(t, k)));
}
printf("\n");
}
@@ -614,8 +606,6 @@ int main(int argc, char* argv[])
{
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
e_device_buf.FromDevice(e_t_k_n_device_result.mData.data());
Tensor<float> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
using ReferenceGemmInstance =

View File

@@ -164,8 +164,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack * 2;
static constexpr auto async_vmcnt =
num_buffer_load_a_scale + num_buffer_load_b_scale + HotLoopInstList::B_Buffer_Load_Inst_Num;
static constexpr auto async_vmcnt = num_buffer_load_a_scale + num_buffer_load_b_scale +
HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384;
static constexpr auto ScalesPerKBlockSize =
@@ -728,7 +728,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
});
});
HotLoopScheduler();
// HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
};
@@ -747,7 +747,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
b_blockwise_copy_up.Run(
b_grid_desc, b_grid_buf_up, b_block_desc, b_block_origin_idx, b_thread_bufs_up(I1));
// Prefetch a_scales
// Prefetch a_scales_up
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
@@ -882,7 +882,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
});
if constexpr(m0.value == SwitchM)
@@ -1133,7 +1133,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
"0x%08x>, "
"b_thread_vec=<0x%08x, 0x%08x, 0x%08x, 0x%08x>, "
"a_scale=0x%08x, "
"b_scale=0x%08x, c_thread_buf=<%.2f, %.2f, %.2f, %.2f>\n",
"b_scale=0x%08x, c_thread_buf=<%.2f, %.2f, %.2f, %.2f>, "
"b_thread_vec_up=<0x%08x, 0x%08x, 0x%08x, 0x%08x>, "
"b_scale_up=0x%08x, c_thread_buf_up=<%.2f, %.2f, %.2f, %.2f>\n",
blockIdx.x,
blockIdx.y,
threadIdx.x,
@@ -1173,6 +1175,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
.template AsType<float>()[Number<2>{}]),
type_convert<float>(
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<float>()[Number<3>{}]),
*(reinterpret_cast<const uint32_t*>(
&(b_thread_vec_up.template AsType<f4x8_t>()[Number<0>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(b_thread_vec_up.template AsType<f4x8_t>()[Number<1>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(b_thread_vec_up.template AsType<f4x8_t>()[Number<2>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(b_thread_vec_up.template AsType<f4x8_t>()[Number<3>{}]))),
*(reinterpret_cast<const uint32_t*>(
&(b_scale_thread_vec_up
.template AsType<BScaleDataType>()[Number<0>{}]))),
type_convert<float>(
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<float>()[Number<0>{}]),
type_convert<float>(
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<float>()[Number<1>{}]),
type_convert<float>(
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<float>()[Number<2>{}]),
type_convert<float>(
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
.template AsType<float>()[Number<3>{}]));
#endif
});