diff --git a/example/01_gemm/gemm_xdl_fp8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp index 3c75a44d21..0c51a58037 100644 --- a/example/01_gemm/gemm_xdl_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp8.cpp @@ -32,6 +32,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; + // this instance has been tested working on gfx950 + // < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 128, 32, 32, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp index 4be4e321d3..e5fe92a50d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp +++ 
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp @@ -124,7 +124,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{}; + static constexpr index_t PrefetchStages = 2; static constexpr index_t PrefillStages = 1; static constexpr index_t GlobalBufferNum = 1; diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index d728360c55..02dba97430 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -519,13 +519,19 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index 50b4a734fa..258d0ad0ca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -452,13 +452,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index 79a9410898..53a45c7f16 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -365,16 +365,20 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_A0K1_B0K1 <= 4) || - (is_same::value && lcm_A0K1_B0K1 <= 8)) + (is_same::value && lcm_A0K1_B0K1 <= 8) || + ((is_same::value || is_same::value) && + lcm_A0K1_B0K1 < 32)) ? true : false; - constexpr auto mfma = MfmaSelector::selected_mfma; - constexpr auto N3 = mfma.num_groups_per_blk; - constexpr auto N5 = mfma.group_size; + is_single_rate_mfma, + is_scale_mfma>::selected_mfma; + constexpr auto N3 = mfma.num_groups_per_blk; + constexpr auto N5 = mfma.group_size; return transform_tensor_descriptor( d0_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple( @@ -657,16 +661,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_A0K1_B0K1 <= 4) || - (is_same::value && lcm_A0K1_B0K1 <= 8)) + (is_same::value && lcm_A0K1_B0K1 <= 8) || + ((is_same::value || is_same::value) && + lcm_A0K1_B0K1 < 32)) ? 
true : false; - constexpr index_t KPack = - math::max(lcm_A0K1_B0K1, - MfmaSelector::selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_A0K1_B0K1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm0 = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index d15767f658..0f2085525f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -347,11 +347,15 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; + constexpr auto is_scale_mfma = false; constexpr auto mfma = - MfmaSelector::selected_mfma; + MfmaSelector:: + selected_mfma; constexpr auto N3 = mfma.num_groups_per_blk; constexpr auto N4 = mfma.num_input_blks; constexpr auto N5 = mfma.group_size; @@ -564,13 +568,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index a11d696019..33b9199ea5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -473,13 +473,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index ab97a940a8..f406bfb95a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -502,13 +502,16 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 79ab3acd92..054aca2936 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -679,17 +679,19 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); + static constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 0e51c6904c..127d889572 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -468,13 +468,16 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value 
&& lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index a3301dd932..be0fff087e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -647,17 +647,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 57b9b02548..7781d1def3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -605,17 +605,20 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; + constexpr auto is_scale_mfma = false; - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp index 88d6be234c..5815eb5b0b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -603,13 +603,19 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( - lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 56581256dc..db227bb7ef 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -455,13 +455,16 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index 23b4aec3b0..70301c326a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -585,13 +585,19 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -1018,13 +1024,19 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp index 44c1e936bd..f64838ea4e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp @@ -599,13 +599,19 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( - lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index d37b3cd38e..4d3ae93659 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -83,13 +83,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index e5e32a8535..4e72255d31 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -144,13 +144,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; __host__ static auto CalculateMPadded(index_t M) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 240bc464e1..7edcd7270f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -814,13 +814,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( - lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp index c7d44e842d..f92268265f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -873,13 +873,19 @@ struct GridwiseGemm_xdl_cshuffle_v2 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( - lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< // BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 29150c0688..0dbbc2a5e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -255,13 +255,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index a22fc06a50..cfa8bfeb2a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -148,13 +148,21 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle static constexpr auto AK1Number = Number{}; static constexpr auto BK1Number = Number{}; - using mfma_selector = MfmaSelector; + // Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t) + // See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3 + // TODO: explore 
optimization opportunity by using new mfma instructions on gfx950 + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = true; + static constexpr auto is_scale_mfma = false; + static constexpr auto mfma = MfmaSelector{}; + static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); + static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); - static constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); - - static constexpr index_t KLane = - mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); static constexpr index_t KRepeat = KPerBlock / KLane / KPack; static constexpr index_t NLane = NPerXdl; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp index 7124687d5d..93c1779a80 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -160,13 +160,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index ac3e821340..97d0e2a4eb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -198,13 +198,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index 4163d1d01a..38ce9536ab 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -183,14 +183,20 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index 21812380c2..ef84dd182a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -153,14 +153,20 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index c0d9464136..8fb955c561 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -164,12 +164,25 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle static constexpr index_t NumDTensor = DsDataType::Size(); - using mfma_selector = MfmaSelector; - static constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); - static constexpr index_t KGroup = 
mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1; - static constexpr index_t KLane = - mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr auto mfma = MfmaSelector{}; + static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); + static constexpr index_t KGroup = mfma.selected_mfma.k_per_blk == 32 ? 2 : 1; + static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); static constexpr index_t KPackPerGroup = KPack / KGroup; static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup; static constexpr index_t NLane = NPerXdl; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index b435fd5d5a..67fb4d651e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -493,13 +493,16 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp index ad65e75ef9..50363d832e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -491,13 +491,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( - lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< TileMathThreadGroupSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 168c553180..b7947309e4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -744,14 +744,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && K1 <= 4) || - (is_same::value && K1 <= 8)) + (is_same::value && K1 <= 8) || + ((is_same::value || is_same::value) && + K1 < 32)) ? true : false; - - constexpr index_t KPack = math::max( - K1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(K1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t k_pack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t k_pack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - static constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); - static constexpr index_t KLane = - mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); - static constexpr index_t KRepeat = KPerBlock / KLane / KPack; - static constexpr index_t NLane = NPerXdl; - static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? 
true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr auto mfma = MfmaSelector{}; + static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); + static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); + static constexpr index_t KRepeat = KPerBlock / KLane / KPack; + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; // static constexpr index_t NumTokens = 1; static constexpr index_t SortedTileSize = MPerBlock; diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 06268f3cfb..b825d7ab69 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1117,12 +1117,31 @@ struct MfmaSelector #endif } + // Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t) + // See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3 + // TODO: explore optimization opportunity by using new mfma instructions on gfx950 template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16f8f8; } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16f8f8; + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16f8f8; +#endif + } + template <> constexpr auto GetMfma() { @@ -1136,11 +1155,21 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32f8f8; } + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32f8f8; +#endif + } + template <> constexpr auto GetMfma() { @@ -1166,41 +1195,101 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr 
auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16bf8bf8; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16bf8bf8; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32bf8bf8; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32bf8bf8; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16f8bf8; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16f8bf8; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32f8bf8; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32f8bf8; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16bf8f8; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16bf8f8; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32bf8f8; } + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32bf8f8; +#endif + } + static constexpr auto selected_mfma = mfma_type::value || - is_same::value) && - KPack <= 4) || - (is_same::value && KPack <= 8)) - ? 
true - : false, - is_scale_mfma > {}; + // Falls back to single rate instruction on gfx950 if KPack is single rate; no change on gfx942- + // when base_type is either f8_t or bf8_t, additional_type will always be either f8_t or bf8_t, + // except for the special case A (f8_t) * B (pk_i4_t), which uses the single rate mfma instruction + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + KPack <= 4) || + (is_same::value && KPack <= 8) || + ((is_same::value || is_same::value) && KPack < 32) || + is_same::value) + ? true + : false; + static constexpr auto mfma = MfmaSelector{}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index ad48389625..ed3354dfb5 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -533,6 +533,50 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) { @@ -1118,6 +1162,52 @@ struct 
intrin_mfma_f32_16x16x128f8f6f4<16, 16> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) {