diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 5b00b5a123..30f03cb53b 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -13,7 +13,9 @@ #include "ck/utility/blkgemmpipe_scheduler.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/sequence.hpp" + #include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" + #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/fill.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index e04f24c989..fcb12f4a14 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -230,6 +230,23 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } }(); + // Pad both M and K to be multiples of the block sizes + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; +#if 0 using GemmSpecialization = tensor_operation::device::GemmSpecialization; if constexpr(GemmSpec == GemmSpecialization::MKPadding || @@ -296,6 +313,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 return a_grid_desc_ak0_m_ak1; } +#endif } __device__ static auto MakeBGridDescriptor_BK0_N_BK1( @@ -312,6 
+330,23 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } }(); + // Pad both N and K to be multiples of the block sizes + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; +#if 0 using GemmSpecialization = tensor_operation::device::GemmSpecialization; if constexpr(GemmSpec == GemmSpecialization::NKPadding || @@ -378,6 +413,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 return b_grid_desc_bk0_n_bk1; } +#endif } template @@ -412,6 +448,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } }(); + // Pad both M and N to be multiples of the block sizes + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 using GemmSpecialization = tensor_operation::device::GemmSpecialization; if constexpr(GemmSpec == GemmSpecialization::MNPadding || @@ -449,6 +492,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 // not pad M or N return c_grid_desc_mraw_nraw; } +#endif } struct Problem @@ -953,7 +997,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + GemmSpec == 
tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)) { if(!(karg.M % MPerBlock == 0)) { @@ -970,7 +1015,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same<remove_cvref_t<CDataType>, bhalf_t>::value)) { if(!(karg.N % NPerBlock == 0)) { @@ -1036,6 +1082,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } } @@ -1051,6 +1098,10 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; return false; } } @@ -1065,6 +1116,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } } @@ -1082,6 +1134,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } } @@ -1098,17 +1151,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } - return false; - } - } - if constexpr(is_same, bhalf_t>::value) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet" - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; + return false; } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp old mode 100644 new mode 100755 diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp index 2f54be7122..425c2c0391 100755 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp @@ -53,11 +53,16 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = // AGPR Spill when use permuted lds layout. so, use padding for these two. 
#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + #endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + + // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, 
PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 16, 1, 16>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on >; @@ -87,6 +92,8 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_instances = DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + + // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 16, 1, 16>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; } // namespace instance diff --git a/profiler/src/profile_gemm_universal_streamk.cpp b/profiler/src/profile_gemm_universal_streamk.cpp index a94bb866f2..b0f66a0c73 100644 --- a/profiler/src/profile_gemm_universal_streamk.cpp 
+++ b/profiler/src/profile_gemm_universal_streamk.cpp @@ -56,6 +56,26 @@ int profile_gemm_universal_streamk(int argc, char* argv[]) exit(1); } + int M; + int N; + int StrideA; + int StrideB; + // Work around unsupported matrix shapes: when M is not a multiple of 8 but N is, swap M with N (and StrideA with StrideB) + if(std::stoi(argv[9]) % 8 != 0 && std::stoi(argv[8]) % 8 == 0) + { + M = std::stoi(argv[9]); + StrideA = std::stoi(argv[12]); + N = std::stoi(argv[8]); + StrideB = std::stoi(argv[11]); + } + else + { + M = std::stoi(argv[8]); + StrideA = std::stoi(argv[11]); + N = std::stoi(argv[9]); + StrideB = std::stoi(argv[12]); + } + const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3])); const bool do_verification = std::stoi(argv[4]); @@ -63,12 +83,8 @@ int profile_gemm_universal_streamk(int argc, char* argv[]) const bool do_log = std::stoi(argv[6]); const bool time_kernel = std::stoi(argv[7]); - const int M = std::stoi(argv[8]); - const int N = std::stoi(argv[9]); const int K = std::stoi(argv[10]); - const int StrideA = std::stoi(argv[11]); - const int StrideB = std::stoi(argv[12]); const int StrideC = std::stoi(argv[13]); const int Streamk_sel = std::stoi(argv[14]); const int Grid_size = std::stoi(argv[15]);