From 1bbb50b212e6ce9253f969cda895afc5a625f4e5 Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Tue, 13 May 2025 20:57:34 -0500 Subject: [PATCH] mfma using asm; device result is correct, host result still needs to be checked --- .../moe_gemm2_xdl_mx_fp4.cpp | 82 ++++++++++++++++++- ...pipeline_xdlops_b_preshuflle_v1_mx_tmp.hpp | 49 +++++++++++ .../gpu/grid/gridwise_moe_mx_gemm.hpp | 2 +- include/ck/utility/amd_xdlops.hpp | 38 +++++++++ 4 files changed, 167 insertions(+), 4 deletions(-) diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp index 948ddb441f..5d9746b114 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp @@ -18,7 +18,7 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp" #include "ck/library/utility/check_err.hpp" - +#include "ck/library/utility/fill.hpp" #include "ck/utility/blkgemmpipe_scheduler.hpp" template @@ -315,8 +315,12 @@ int main(int argc, char* argv[]) d2_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; case 2: - a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + // a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{1.0, 1.0}); + // b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{1.0, 1.0}); + ck::utils::FillConstant{ck::type_convert(ck::float2_t(1.0f))}( + a0_t_k_k); + ck::utils::FillConstant{ck::type_convert(ck::float2_t(1.0f))}( + b0_e_n_k); a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); // will to remove @@ -360,6 +364,78 @@ int main(int argc, char* argv[]) auto b_element_op = BElementOp{}; auto cde_element_op = CDEElementOp{}; +#if 1 + printf("a0_t_k_k:\n"); + for(int t = 0; t < tokens; ++t) + { + for(int tk = 0; tk < topk; ++tk) + { + for(int k = 0; k < K; ++k) + { + if(k % 2 == 
0) + { + printf("%f ", ck::type_convert(a0_t_k_k(t, tk, k).data >> 4 & 0xf)); + } + else + { + printf("%f ", ck::type_convert(a0_t_k_k(t, tk, k).data & 0xf)); + } + } + printf("\n"); + } + printf("\n"); + } + + printf("a1_t_k_k:\n"); + for(int t = 0; t < tokens; ++t) + { + for(int tk = 0; tk < topk; ++tk) + { + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k) + { + printf("%f ", ck::type_convert(a1_t_k_k(t, tk, k))); + } + printf("\n"); + } + printf("\n"); + } + + printf("b0_e_n_k:\n"); + for(int e = 0; e < experts; ++e) + { + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + if(k % 2 == 0) + { + printf("%f ", ck::type_convert(b0_e_n_k(e, k, n).data >> 4 & 0xf)); + } + else + { + printf("%f ", ck::type_convert(b0_e_n_k(e, k, n).data & 0xf)); + } + } + printf("\n"); + } + printf("\n"); + } + + printf("b1_e_n_k:\n"); + for(int e = 0; e < experts; ++e) + { + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k) + { + printf("%f ", ck::type_convert(b1_e_n_k(e, k, n))); + } + printf("\n"); + } + printf("\n"); + } +#endif + // do GEMM auto device_op = DeviceOpInstance{}; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx_tmp.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx_tmp.hpp index 5276c5278c..713ba1049b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx_tmp.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx_tmp.hpp @@ -778,6 +778,34 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp{}]; }); +#if 0 + printf("bidx: %u, bidy: %u, tidx: %u, a_thread_vec: %02x, %02x, %02x, %02x, b_thread_vec: %02x, %02x, %02x, %02x," + "a_scale_thread_vec: %02x, b_scale_vec: %02x\n", + blockIdx.x, + blockIdx.y, + threadIdx.x, + + *(reinterpret_cast(&(a_thread_vec.template AsType()[Number<0>{}]))), + 
*(reinterpret_cast(&(a_thread_vec.template AsType()[Number<1>{}]))), + *(reinterpret_cast(&(a_thread_vec.template AsType()[Number<2>{}]))), + *(reinterpret_cast(&(a_thread_vec.template AsType()[Number<3>{}]))), + *(reinterpret_cast(&(b_thread_vec.template AsType()[Number<0>{}]))), + *(reinterpret_cast(&(b_thread_vec.template AsType()[Number<1>{}]))), + *(reinterpret_cast(&(b_thread_vec.template AsType()[Number<2>{}]))), + *(reinterpret_cast(&(b_thread_vec.template AsType()[Number<3>{}]))), + + // type_convert(a_thread_vec.template AsType()[Number<0>{}].unpack(Number<0>{})), + // type_convert(a_thread_vec.template AsType()[Number<0>{}].unpack(Number<1>{})), + // type_convert(a_thread_vec.template AsType()[Number<1>{}].unpack(Number<0>{})), + // type_convert(a_thread_vec.template AsType()[Number<1>{}].unpack(Number<1>{})), + // type_convert(b_thread_vec.template AsType()[Number<0>{}].unpack(Number<0>{})), + // type_convert(b_thread_vec.template AsType()[Number<0>{}].unpack(Number<1>{})), + // type_convert(b_thread_vec.template AsType()[Number<1>{}].unpack(Number<0>{})), + // type_convert(b_thread_vec.template AsType()[Number<1>{}].unpack(Number<1>{})), + *(reinterpret_cast(&(a_scale_thread_vec.template AsType()[Number<0>{}]))), + *(reinterpret_cast(&(b_scale_thread_vec.template AsType()[Number<0>{}])))); +#endif + using mfma_input_type_a = typename vector_type::type; @@ -796,8 +824,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx_tmp(), c_thread_buf.GetVectorTypeReference(Number{})); }); +#if 0 + printf("bidx: %u, bidx: %u, tidx: %u, c_thread_buf: %f, %f, %f, %f\n", + blockIdx.x, + blockIdx.y, + threadIdx.x, + (c_thread_buf[Number<0>{}]), + (c_thread_buf[Number<1>{}]), + (c_thread_buf[Number<2>{}]), + (c_thread_buf[Number<3>{}])); +#endif }); }); +#if 0 + printf("bidx: %u, bidx: %u, tidx: %u, c_thread_buf: %f, %f, %f, %f\n", + blockIdx.x, + blockIdx.y, + threadIdx.x, + type_convert(c_thread_buf.GetVectorTypeReference(Number<0>{}).template 
AsType()[Number<0>{}]), + type_convert(c_thread_buf.GetVectorTypeReference(Number<0>{}).template AsType()[Number<1>{}]), + type_convert(c_thread_buf.GetVectorTypeReference(Number<0>{}).template AsType()[Number<2>{}]), + type_convert(c_thread_buf.GetVectorTypeReference(Number<0>{}).template AsType()[Number<3>{}])); + +#endif } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp index 0341428d8f..ad818d7a2b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp @@ -1310,7 +1310,7 @@ struct GridwiseMoeGemmMX { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - gather_offsets(m0) = static_cast(token_offset) * problem.K; + gather_offsets(m0) = static_cast(token_offset) * problem.K / APackedSize; }); const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 85be91a71e..4d75c12052 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -762,6 +762,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> using arg_type = int32x8_t; +#if 1 reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, @@ -773,6 +774,43 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> scale_a, 0, // OPSEL scale_b); +#else + asm volatile("v_mfma_scale_f32_16x16x128_f8f6f4 %0, %1, %2, %3, %4, %5 " + "op_sel:[0,0]" + "op_sel_hi:[0,0]" + "cbsz:4" + " blgp:4" + : "+v"(reg_c.template AsType()(Number<0>{})) + : "v"(bit_cast(arg_a)), + "v"(bit_cast(arg_b)), + "v"(reg_c.template AsType()[Number<0>{}]), + "v"(scale_a), + "v"(scale_b)); +#endif + +#if 1 + printf("bidx: %u, bidy: %u, tid: %u, A: %08x, %08x, %08x, %08x," + "B:%08x, %08x, %08x, 
%08x, a_scale: %08x, b_scale: %08x, " + "reg_c: %f, %f, %f, %f\n", + blockIdx.x, + blockIdx.y, + threadIdx.x, + bit_cast(arg_a[0]), + bit_cast(arg_a[1]), + bit_cast(arg_a[2]), + bit_cast(arg_a[3]), + bit_cast(arg_b[0]), + bit_cast(arg_b[1]), + bit_cast(arg_b[2]), + bit_cast(arg_b[3]), + *(reinterpret_cast(&(scale_a))), + *(reinterpret_cast(&(scale_b))), + reg_c.template AsType()[Number<0>{}], + reg_c.template AsType()[Number<1>{}], + reg_c.template AsType()[Number<2>{}], + reg_c.template AsType()[Number<3>{}]); +#endif + #else ignore = reg_a; ignore = scale_a;