diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp index 5de7ad59fe..e0cbfba1b7 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp @@ -155,7 +155,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // clang-format on #else -static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t MPerBlock = 16; static constexpr bool MulRoutedWeight = true; // clang-format off @@ -163,14 +163,14 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic A0Layout, B0Layout, DsLayout, ELayout, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - ScaleBlockSize, 256, - MPerBlock, 128, 128, + ScaleBlockSize, 64, + MPerBlock, 16, 128, 32, 32, 16, 16, - 8, 2, - S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - 1, 1, S<1, 16, 1, 16>, S<2, 1, 1, 1>, + 1, 1, + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, + 1, 1, S<1, 8, 1, 8>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>; // clang-format on #endif @@ -408,12 +408,12 @@ int main(int argc, char* argv[]) if(k % 2 == 0) { ck::f4_t f4 = (f4x2 >> 4) & 0xf; - printf("%f ", ck::type_convert(f4)); + printf("%.2f ", ck::type_convert(f4)); } else { ck::f4_t f4 = (f4x2 >> 0) & 0xf; - printf("%f ", ck::type_convert(f4)); + printf("%.2f ", ck::type_convert(f4)); } } printf("\n"); @@ -428,7 +428,7 @@ int main(int argc, char* argv[]) { for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k) { - printf("%f ", ck::type_convert(a1_t_k_k(t, tk, k))); + printf("%.2f ", ck::type_convert(a1_t_k_k(t, tk, k))); } printf("\n"); } @@ -446,12 +446,12 @@ int main(int argc, char* argv[]) if(k % 2 == 0) { ck::f4_t f4 = f4x2 >> 4 & 0xf; - printf("%f ", ck::type_convert(f4)); + printf("%.2f ", ck::type_convert(f4)); } else { ck::f4_t f4 = f4x2 >> 0 & 0xf; - printf("%f ", ck::type_convert(f4)); + printf("%.2f ", ck::type_convert(f4)); } } printf("\n"); @@ -466,7 +466,7 @@ int main(int argc, char* argv[]) { for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k) { - printf("%f ", ck::type_convert(b1_e_n_k(e, k, n))); + printf("%.2f ", ck::type_convert(b1_e_n_k(e, k, n))); } printf("\n"); } @@ -598,7 +598,7 @@ int main(int argc, char* argv[]) { for(int n = 0; n < N; ++n) { - printf("%f ", ck::type_convert(e_t_n_device_result(t, n))); + printf("%.2f ", ck::type_convert(e_t_n_device_result(t, n))); } printf("\n"); } @@ -608,7 +608,7 @@ int main(int argc, char* argv[]) { for(int n = 0; n < N; ++n) { - printf("%f ", ck::type_convert(e_t_n_host_result(t, n))); + printf("%.2f ", ck::type_convert(e_t_n_host_result(t, n))); } printf("\n"); }