mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 05:37:34 +00:00
mxfp4 moe gemm1 function passed
This commit is contained in:
@@ -86,38 +86,6 @@ struct MulABScaleExpertWeight
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
// B preshuffle
|
||||
void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
I64 tempk;
|
||||
for(I64 n = 0; n < N; ++n)
|
||||
{
|
||||
for(I64 k = 0; k < K_pk; ++k)
|
||||
{
|
||||
I64 n0 = n / NLane;
|
||||
I64 n1 = n % NLane;
|
||||
|
||||
I64 k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
I64 k1 = tempk / KPack;
|
||||
I64 k2 = tempk % KPack;
|
||||
|
||||
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * K_pk + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A, B Scale preshuffle
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
@@ -214,8 +182,8 @@ int main(int argc, char* argv[])
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 6144;
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
@@ -224,7 +192,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 3)
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
@@ -344,11 +312,6 @@ int main(int argc, char* argv[])
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
|
||||
// a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
// b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
|
||||
// a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
// b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
@@ -448,23 +411,19 @@ int main(int argc, char* argv[])
|
||||
printf("a0_t_k_k:\n");
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
//for(int tk = 0; tk < topk; ++tk)
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
auto f4x2 = a0_t_k(t, k).data;
|
||||
if(k % 2 == 0)
|
||||
{
|
||||
auto f4x2 = a0_t_k(t, k).data;
|
||||
if(k % 2 == 0)
|
||||
{
|
||||
ck::f4_t f4 = (f4x2 >> 4) & 0xf;
|
||||
printf("%.2f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::f4_t f4 = (f4x2 >> 0) & 0xf;
|
||||
printf("%.2f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
ck::f4_t f4 = (f4x2 >> 4) & 0xf;
|
||||
printf("%.2f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::f4_t f4 = (f4x2 >> 0) & 0xf;
|
||||
printf("%.2f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
@@ -472,13 +431,9 @@ int main(int argc, char* argv[])
|
||||
printf("a1_t_k_k:\n");
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int tk = 0; tk < topk; ++tk)
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k)
|
||||
{
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k)
|
||||
{
|
||||
printf("%.2f ", ck::type_convert<float>(a1_t_k_k(t, tk, k)));
|
||||
}
|
||||
printf("\n");
|
||||
printf("%.2f ", ck::type_convert<float>(a1_t_k(t, k)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
@@ -207,15 +207,16 @@ int main(int argc, char* argv[])
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 13;
|
||||
ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
constexpr ck::index_t sorted_tile_num = 13;
|
||||
constexpr ck::index_t valid_tile_num = sorted_tile_num;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
@@ -263,8 +264,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
max_token_id.mData = {valid_size};
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({sorted_tile_num + 1}));
|
||||
max_token_id.mData[0] = valid_size;
|
||||
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
@@ -340,29 +341,28 @@ int main(int argc, char* argv[])
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{1});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
@@ -433,23 +433,19 @@ int main(int argc, char* argv[])
|
||||
printf("a0_t_k_k:\n");
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
//for(int tk = 0; tk < topk; ++tk)
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
auto f4x2 = a0_t_k(t, k).data;
|
||||
if(k % 2 == 0)
|
||||
{
|
||||
auto f4x2 = a0_t_k(t, k).data;
|
||||
if(k % 2 == 0)
|
||||
{
|
||||
ck::f4_t f4 = (f4x2 >> 4) & 0xf;
|
||||
printf("%.2f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::f4_t f4 = (f4x2 >> 0) & 0xf;
|
||||
printf("%.2f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
ck::f4_t f4 = (f4x2 >> 4) & 0xf;
|
||||
printf("%.2f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::f4_t f4 = (f4x2 >> 0) & 0xf;
|
||||
printf("%.2f ", ck::type_convert<float>(f4));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
@@ -457,13 +453,9 @@ int main(int argc, char* argv[])
|
||||
printf("a1_t_k_k:\n");
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int tk = 0; tk < topk; ++tk)
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k)
|
||||
{
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k)
|
||||
{
|
||||
printf("%.2f ", ck::type_convert<float>(a1_t_k_k(t, tk, k)));
|
||||
}
|
||||
printf("\n");
|
||||
printf("%.2f ", ck::type_convert<float>(a1_t_k(t, k)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
@@ -614,8 +606,6 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
e_device_buf.FromDevice(e_t_k_n_device_result.mData.data());
|
||||
|
||||
Tensor<float> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
|
||||
@@ -164,8 +164,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
|
||||
|
||||
static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
|
||||
static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack * 2;
|
||||
static constexpr auto async_vmcnt =
|
||||
num_buffer_load_a_scale + num_buffer_load_b_scale + HotLoopInstList::B_Buffer_Load_Inst_Num;
|
||||
static constexpr auto async_vmcnt = num_buffer_load_a_scale + num_buffer_load_b_scale +
|
||||
HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
|
||||
static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384;
|
||||
|
||||
static constexpr auto ScalesPerKBlockSize =
|
||||
@@ -728,7 +728,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
// HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
};
|
||||
|
||||
@@ -747,7 +747,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
|
||||
b_blockwise_copy_up.Run(
|
||||
b_grid_desc, b_grid_buf_up, b_block_desc, b_block_origin_idx, b_thread_bufs_up(I1));
|
||||
|
||||
// Prefetch a_scales
|
||||
// Prefetch a_scales_up
|
||||
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
@@ -882,7 +882,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
|
||||
a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec_up.template AsType<mfma_scale_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
if constexpr(m0.value == SwitchM)
|
||||
@@ -1133,7 +1133,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
|
||||
"0x%08x>, "
|
||||
"b_thread_vec=<0x%08x, 0x%08x, 0x%08x, 0x%08x>, "
|
||||
"a_scale=0x%08x, "
|
||||
"b_scale=0x%08x, c_thread_buf=<%.2f, %.2f, %.2f, %.2f>\n",
|
||||
"b_scale=0x%08x, c_thread_buf=<%.2f, %.2f, %.2f, %.2f>, "
|
||||
"b_thread_vec_up=<0x%08x, 0x%08x, 0x%08x, 0x%08x>, "
|
||||
"b_scale_up=0x%08x, c_thread_buf_up=<%.2f, %.2f, %.2f, %.2f>\n",
|
||||
blockIdx.x,
|
||||
blockIdx.y,
|
||||
threadIdx.x,
|
||||
@@ -1173,6 +1175,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
|
||||
.template AsType<float>()[Number<2>{}]),
|
||||
type_convert<float>(
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<float>()[Number<3>{}]),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(b_thread_vec_up.template AsType<f4x8_t>()[Number<0>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(b_thread_vec_up.template AsType<f4x8_t>()[Number<1>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(b_thread_vec_up.template AsType<f4x8_t>()[Number<2>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(b_thread_vec_up.template AsType<f4x8_t>()[Number<3>{}]))),
|
||||
*(reinterpret_cast<const uint32_t*>(
|
||||
&(b_scale_thread_vec_up
|
||||
.template AsType<BScaleDataType>()[Number<0>{}]))),
|
||||
type_convert<float>(
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<float>()[Number<0>{}]),
|
||||
type_convert<float>(
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<float>()[Number<1>{}]),
|
||||
type_convert<float>(
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<float>()[Number<2>{}]),
|
||||
type_convert<float>(
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<float>()[Number<3>{}]));
|
||||
#endif
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user