diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp index 8126f76572..30254631ce 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp @@ -86,38 +86,6 @@ struct MulABScaleExpertWeight using CDEElementOp = MulABScaleExpertWeight; -// B preshuffle -void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl) -{ - int KPack = 16; - int NLane = NXdl; - int KLane = 64 / NLane; - int K_pk = K / 2; - int K0 = K_pk / (KLane * KPack); - // K -> K0 KLane KPack - // N -> N0 NLane - // N, K -> N0 K0 KLane NLane KPack - I64 tempk; - for(I64 n = 0; n < N; ++n) - { - for(I64 k = 0; k < K_pk; ++k) - { - I64 n0 = n / NLane; - I64 n1 = n % NLane; - - I64 k0 = k / (KLane * KPack); - tempk = k % (KLane * KPack); - I64 k1 = tempk / KPack; - I64 k2 = tempk % KPack; - - I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + - k1 * KPack * NLane + n1 * KPack + k2; - - dst[outputIndex] = src[n * K_pk + k]; - } - } -} - // A, B Scale preshuffle template void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) @@ -214,8 +182,8 @@ int main(int argc, char* argv[]) ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; - ck::index_t N = 4096; - ck::index_t K = 6144; + ck::index_t N = 6144; + ck::index_t K = 4096; ck::index_t experts = 8; ck::index_t tokens = 832; ck::index_t topk = 2; @@ -224,7 +192,7 @@ int main(int argc, char* argv[]) { // use default case } - else if(argc == 3) + else if(argc == 4) { // use default case do_verification = std::stoi(argv[1]); @@ -344,11 +312,6 @@ int main(int argc, char* argv[]) a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); - - // a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); - // b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - // a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - // b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); break; case 3: a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); @@ -448,23 +411,19 @@ int main(int argc, char* argv[]) printf("a0_t_k_k:\n"); for(int t = 0; t < tokens; ++t) { - //for(int tk = 0; tk < topk; ++tk) + for(int k = 0; k < K; ++k) { - for(int k = 0; k < K; ++k) + auto f4x2 = a0_t_k(t, k).data; + if(k % 2 == 0) { - auto f4x2 = a0_t_k(t, k).data; - if(k % 2 == 0) - { - ck::f4_t f4 = (f4x2 >> 4) & 0xf; - printf("%.2f ", ck::type_convert(f4)); - } - else - { - ck::f4_t f4 = (f4x2 >> 0) & 0xf; - printf("%.2f ", ck::type_convert(f4)); - } + ck::f4_t f4 = (f4x2 >> 4) & 0xf; + printf("%.2f ", ck::type_convert(f4)); + } + else + { + ck::f4_t f4 = (f4x2 >> 0) & 0xf; + printf("%.2f ", ck::type_convert(f4)); } - printf("\n"); } printf("\n"); } @@ -472,13 +431,9 @@ int main(int argc, char* argv[]) printf("a1_t_k_k:\n"); for(int t = 0; t < tokens; ++t) { - for(int tk = 0; tk < topk; ++tk) + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k) { - for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k) - { - printf("%.2f ", ck::type_convert(a1_t_k_k(t, tk, k))); - } - printf("\n"); + printf("%.2f ", ck::type_convert(a1_t_k(t, k))); } printf("\n"); } diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp index da29c8e972..7eb6a5165f 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp @@ -207,15 +207,16 @@ int main(int argc, char* argv[]) // per expert: // GEMM shape - ck::index_t N = 6144; - ck::index_t K = 4096; - ck::index_t experts = 8; - ck::index_t sorted_tile_num = 13; - ck::index_t valid_tile_num = 13; - ck::index_t sorted_size = sorted_tile_num * MPerBlock; - ck::index_t valid_size = valid_tile_num * MPerBlock; - ck::index_t tokens = 832; - ck::index_t topk = 2; + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; if(argc == 1) { @@ -263,8 +264,8 @@ int main(int argc, char* argv[]) Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); - Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); - max_token_id.mData = {valid_size}; + Tensor max_token_id(HostTensorDescriptor({sorted_tile_num + 1})); + max_token_id.mData[0] = valid_size; if(tokens * topk > valid_size) { @@ -340,29 +341,28 @@ int main(int argc, char* argv[]) b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); - + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); break; case 3: a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); - a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); break; case 4: - a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); - b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); - d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); break; case 5: a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - d2_e_n.GenerateTensorValue(GeneratorTensor_1{1}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); break; case 6: a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); @@ -433,23 +433,19 @@ int main(int argc, char* argv[]) printf("a0_t_k_k:\n"); for(int t = 0; t < tokens; ++t) { - //for(int tk = 0; tk < topk; ++tk) + for(int k = 0; k < K; ++k) { - for(int k = 0; k < K; ++k) + auto f4x2 = a0_t_k(t, k).data; + if(k % 2 == 0) { - auto f4x2 = a0_t_k(t, k).data; - if(k % 2 == 0) - { - ck::f4_t f4 = (f4x2 >> 4) & 0xf; - printf("%.2f ", ck::type_convert(f4)); - } - else - { - ck::f4_t f4 = (f4x2 >> 0) & 0xf; - printf("%.2f ", ck::type_convert(f4)); - } + ck::f4_t f4 = (f4x2 >> 4) & 0xf; + printf("%.2f ", ck::type_convert(f4)); + } + else + { + ck::f4_t f4 = (f4x2 >> 0) & 0xf; + printf("%.2f ", ck::type_convert(f4)); } - printf("\n"); } printf("\n"); } @@ -457,13 +453,9 @@ int main(int argc, char* argv[]) printf("a1_t_k_k:\n"); for(int t = 0; t < tokens; ++t) { - for(int tk = 0; tk < topk; ++tk) + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k) { - for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k) - { - printf("%.2f ", ck::type_convert(a1_t_k_k(t, tk, k))); - } - printf("\n"); + printf("%.2f ", ck::type_convert(a1_t_k(t, k))); } printf("\n"); } @@ -614,8 +606,6 @@ int main(int argc, char* argv[]) { invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - e_device_buf.FromDevice(e_t_k_n_device_result.mData.data()); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); using ReferenceGemmInstance = diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp index db194b25b7..b15df891eb 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp @@ -164,8 +164,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack * 2; - static constexpr auto async_vmcnt = - num_buffer_load_a_scale + num_buffer_load_b_scale + HotLoopInstList::B_Buffer_Load_Inst_Num; + static constexpr auto async_vmcnt = num_buffer_load_a_scale + num_buffer_load_b_scale + + HotLoopInstList::B_Buffer_Load_Inst_Num * 2; static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384; static constexpr auto ScalesPerKBlockSize = @@ -728,7 +728,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< }); }); - HotLoopScheduler(); + // HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); }; @@ -747,7 +747,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< b_blockwise_copy_up.Run( b_grid_desc, b_grid_buf_up, b_block_desc, b_block_origin_idx, b_thread_bufs_up(I1)); - // Prefetch a_scales + // Prefetch a_scales_up static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { a_scale_thread_copy.Run(a_scale_grid_desc, @@ -882,7 +882,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< a_scale_thread_vec.template AsType(), b_thread_vec_up.template AsType(), b_scale_thread_vec_up.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + c_thread_buf_up.GetVectorTypeReference(Number{})); }); }); if constexpr(m0.value == SwitchM) @@ -1133,7 +1133,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< "0x%08x>, " "b_thread_vec=<0x%08x, 0x%08x, 0x%08x, 0x%08x>, " "a_scale=0x%08x, " - "b_scale=0x%08x, c_thread_buf=<%.2f, %.2f, %.2f, %.2f>\n", + "b_scale=0x%08x, c_thread_buf=<%.2f, %.2f, %.2f, %.2f>, " + "b_thread_vec_up=<0x%08x, 0x%08x, 0x%08x, 0x%08x>, " + "b_scale_up=0x%08x, c_thread_buf_up=<%.2f, %.2f, %.2f, %.2f>\n", blockIdx.x, blockIdx.y, threadIdx.x, @@ -1173,6 +1175,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< .template AsType()[Number<2>{}]), type_convert( c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[Number<3>{}]), + *(reinterpret_cast( + &(b_thread_vec_up.template AsType()[Number<0>{}]))), + *(reinterpret_cast( + &(b_thread_vec_up.template AsType()[Number<1>{}]))), + *(reinterpret_cast( + &(b_thread_vec_up.template AsType()[Number<2>{}]))), + *(reinterpret_cast( + &(b_thread_vec_up.template AsType()[Number<3>{}]))), + *(reinterpret_cast( + &(b_scale_thread_vec_up + .template AsType()[Number<0>{}]))), + type_convert( + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[Number<0>{}]), + type_convert( + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[Number<1>{}]), + type_convert( + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[Number<2>{}]), + type_convert( + c_thread_buf_up.GetVectorTypeReference(Number{}) .template AsType()[Number<3>{}])); #endif });